# Utility ----

# Plot color palette
# Draws a single stacked bar with one segment per color in input_cols so the
# palette can be inspected visually. Returns a ggplot object.
#   input_cols  vector of colors, names are ignored
plot_color_palette <- function(input_cols) {

  # Names (e.g. cell types) are dropped and every color is mapped to itself.
  # A named palette would otherwise be matched by name against the fill
  # levels, which are the color values, and nothing would match.
  cols    <- as.character(unname(input_cols))
  col_key <- set_names(unique(cols))

  res <- tibble(color = cols) %>%
    mutate(color = fct_inorder(color)) %>%
    ggplot(aes(x = "color", fill = color)) +
    geom_bar() +
    scale_fill_manual(values = col_key) +
    theme_void()

  res
}

# Create color palette
# Interpolates the colors in cols_in into a gradient of n colors. When n is
# not given, one color is returned per input color.
create_gradient <- function(cols_in, n = NULL) {

  n_out <- if (is.null(n)) length(cols_in) else n

  ramp <- colorRampPalette(cols_in)

  ramp(n_out)
}

# Create a palette function for a fixed set of colors
# Returns a function of n that calls create_gradient() with cols_in.
create_col_fun <- function(cols_in) {

  # Evaluate cols_in now so the returned closure does not hold a lazy promise
  # that could pick up a later value of the caller's variable
  force(cols_in)

  function(n = NULL) {
    create_gradient(cols_in, n)
  }
}

# Pull items from list of vectors
# Extracts element idx from every entry of list_in and returns the extracted
# values as a character vector.
pull_nest_vec <- function(list_in, idx) {
  map_chr(list_in, function(vec) vec[[idx]])
}

# Capitalize first character without modifying other characters
# Every space-separated word in string gets its first character upper-cased.
# Words that already contain an upper-case letter (A-Z) and words listed in
# exclude are left untouched. NA elements are returned as NA.
#   string   character vector to format
#   exclude  word(s) that should never be capitalized
str_to_title_v2 <- function(string, exclude = "cell") {

  # Capitalize the words of a single, already split, string
  cap_words <- function(words) {
    skip <- words %in% exclude | str_detect(words, "[A-Z]")
    idx  <- which(!skip)

    words[idx] <- str_c(
      str_to_upper(str_sub(words[idx], 1, 1)),
      str_sub(words[idx], 2)
    )

    words
  }

  res <- string %>%
    map_chr(function(str) {

      # Missing values cannot be split or compared, pass them through
      if (is.na(str)) {
        return(NA_character_)
      }

      str %>%
        str_split(pattern = " ") %>%
        unlist() %>%
        cap_words() %>%
        str_c(collapse = " ")
    })

  res
}

# Set colors
# Builds a named color vector. Types named in other_cols keep the colors given
# there; the remaining types are paired, in order, with the colors of cols_in
# that do not occur in other_cols. Colors left without a type are dropped.
set_cols <- function(types_in, cols_in, other_cols) {

  free_type <- !(types_in %in% names(other_cols))
  free_col  <- !(cols_in %in% other_cols)

  palette <- cols_in[free_col]
  names(palette) <- types_in[free_type]

  # Surplus colors end up with NA names and are removed here
  has_type <- !is.na(names(palette))

  c(palette[has_type], other_cols)
}

# Set subtype colors
# Selects the objects in sobjs_in whose entry in type_key equals type_in,
# collects the unique values of type_column from their meta.data and assigns
# colors to them with set_cols().
set_type_cols <- function(type_in, sobjs_in, type_key, type_column = "subtype",
                          cols_in, other_cols) {

  # Objects are matched to type_key by name
  is_type   <- type_key[names(sobjs_in)] == type_in
  type_objs <- sobjs_in[is_type]

  # Unique labels per object, then across objects
  type_labels <- map(type_objs, function(sobj) {
    unique(pull(sobj@meta.data, type_column))
  })

  type_labels <- unique(reduce(type_labels, c))

  set_cols(
    type_labels,
    cols_in    = cols_in,
    other_cols = other_cols
  )
}


# Processing ----

# Import matrices and create Seurat object
# Reads the matrices in matrix_dir with Read10X() and builds a Seurat object
# from the gene expression counts. When Read10X() returns a list of matrices,
# the "Antibody Capture" counts are stored in an "ADT" assay.
#   matrix_dir     directory passed to Read10X()
#   proj_name      project name for the Seurat object
#   hash_ids       when the first element is not NULL it is passed as HTO_list
#                  to add_HTO_assay() and cells are demultiplexed with
#                  HTODemux(); otherwise names(hash_ids) is stored in hash.ID
#   adt_count_min  antibody features with fewer total counts are removed
#   gene_min, gene_max, mito_max
#                  thresholds used to set the qc_class column
#   mt_str         pattern used to identify mitochondrial genes
#   ...            passed to HTODemux()
create_sobj <- function(matrix_dir, proj_name = "SeuratProject", hash_ids = NULL, adt_count_min = 0,
                        gene_min = 250, gene_max = 5000, mito_max = 15, mt_str = "^mt-", ...) {

  use_hashes <- !is.null(hash_ids[[1]])

  # Load matrices
  mat_list <- Read10X(matrix_dir)
  has_adt  <- is_list(mat_list)

  # Hashing needs the antibody matrix, fail with a clear message instead of
  # "object 'adt_mat' not found" further down
  if (use_hashes && !has_adt) {
    stop("hash_ids were provided but Read10X() did not return an antibody capture matrix")
  }

  # Create Seurat object using gene expression data
  rna_mat <- mat_list

  if (has_adt) {
    rna_mat <- mat_list[["Gene Expression"]]
  }

  res <- rna_mat %>%
    CreateSeuratObject(
      project   = proj_name,
      min.cells = 5
    )

  # Add antibody capture data to Seurat object
  # drop = FALSE keeps a matrix even when a single antibody or cell remains
  if (has_adt) {
    adt_mat  <- mat_list[["Antibody Capture"]][, colnames(res), drop = FALSE]
    keep_adt <- rowSums(as.matrix(adt_mat)) >= adt_count_min
    adt_mat  <- adt_mat[keep_adt, , drop = FALSE]

    res[["ADT"]] <- CreateAssayObject(adt_mat)
  }

  # If hash ids are given divide into ADT and HTO assays
  if (use_hashes) {
    res <- res %>%
      add_HTO_assay(
        ADT_mat  = adt_mat,
        HTO_list = hash_ids
      ) %>%
      NormalizeData(
        assay = "HTO",
        normalization.method = "CLR"
      ) %>%
      HTODemux(...)

  } else {
    res@meta.data <- res@meta.data %>%
      rownames_to_column("cell_ids") %>%
      mutate(hash.ID = names(hash_ids)) %>%
      column_to_rownames("cell_ids")
  }

  # Calculate percentage of mitochondrial reads
  res <- res %>%
    PercentageFeatureSet(
      pattern  = mt_str,
      col.name = "Percent_mito"
    )

  # Add QC classifications to meta.data
  # Hashed samples start from the HTODemux() classification, where "Singlet"
  # is treated as passing; the gene count and mito filters are applied last
  # and overwrite earlier classes
  meta <- res@meta.data %>%
    rownames_to_column("cell_id") %>%
    mutate(qc_class = "Pass filters")

  if (use_hashes) {
    meta <- meta %>%
      mutate(qc_class = HTO_classification.global)
  }

  res@meta.data <- meta %>%
    mutate(
      qc_class = str_replace(qc_class, "Singlet", "Pass filters"),
      qc_class = ifelse(nFeature_RNA < gene_min, "Low gene count", qc_class),
      qc_class = ifelse(nFeature_RNA > gene_max, "High gene count", qc_class),
      qc_class = ifelse(Percent_mito > mito_max, "High mito reads", qc_class)
    ) %>%
    column_to_rownames("cell_id")

  res
}

# Filter and normalize matrices, find variable features
# Keeps cells with qc_class "Pass filters", log-normalizes the counts, scores
# cell cycle genes, finds variable features and scales the data. If adt_assay
# is present in the object it is CLR-normalized and scaled as well.
#   regress_vars  passed to ScaleData() as vars.to.regress
norm_sobj <- function(sobj_in, adt_assay = "ADT", regress_vars = NULL) {

  # Normalize counts
  pass_cells <- subset(sobj_in, qc_class == "Pass filters")

  res <- NormalizeData(pass_cells, normalization.method = "LogNormalize")

  # Score cell cycle genes
  # Gene symbols from cc.genes are converted to title case before scoring
  res <- CellCycleScoring(
    res,
    s.features   = str_to_title(cc.genes$s.genes),
    g2m.features = str_to_title(cc.genes$g2m.genes)
  )

  # Scale data
  res <- FindVariableFeatures(res, selection.method = "vst", nfeatures = 2000)
  res <- ScaleData(res, vars.to.regress = regress_vars)

  # Normalize ADT data
  if (!adt_assay %in% names(res)) {
    return(res)
  }

  res <- NormalizeData(res, assay = adt_assay, normalization.method = "CLR")

  ScaleData(res, assay = adt_assay)
}

# Run PCA and UMAP for gene expression data
# The UMAP reduction is stored as "<prefix>umap" with key "<prefix>UMAP_".
#   dims       PCA dimensions passed to RunUMAP()
#   pca_meta   add PC_1 and PC_2 (with prefix) to meta.data
#   umap_meta  add the UMAP coordinates to meta.data
#   ...        passed to RunPCA()
run_UMAP_RNA <- function(sobj_in, assay = "RNA", dims = 1:40, prefix = "",
                         pca_meta = T, umap_meta = T, ...) {

  # Reduction keys
  umap_name <- str_c(prefix, "umap")
  umap_key  <- str_c(prefix, "UMAP_")

  # Run PCA, run UMAP
  res <- RunPCA(sobj_in, assay = assay, ...)

  res <- RunUMAP(
    res,
    assay          = assay,
    dims           = dims,
    reduction.name = umap_name,
    reduction.key  = umap_key
  )

  # Add PCA to meta.data
  if (pca_meta) {
    pcs <- c("PC_1", "PC_2")

    res <- AddMetaData(
      res,
      metadata = FetchData(res, pcs),
      col.name = str_c(prefix, pcs)
    )
  }

  # Add UMAP to meta.data
  if (umap_meta) {
    res <- AddMetaData(
      res,
      metadata = Embeddings(res, reduction = umap_name),
      col.name = str_c(umap_key, c(1, 2))
    )
  }

  res
}

# Run PCA and UMAP, then cluster cells
# Calls run_UMAP_RNA() and then clusters the cells with FindNeighbors() and
# FindClusters(). The cluster identities are stored in the meta.data column
# "<assay>_clusters".
#   ...  passed to run_UMAP_RNA()
cluster_RNA <- function(sobj_in, assay = "RNA", resolution = 0.6, dims = 1:40, 
                        prefix = "", pca_meta = T, umap_meta = T, ...) {

  # Run PCA, run UMAP
  embedded <- run_UMAP_RNA(
    sobj_in,
    assay     = assay,
    prefix    = prefix,
    dims      = dims,
    pca_meta  = pca_meta,
    umap_meta = umap_meta,
    ...
  )

  # FindNeighbors constructs a K-nearest neighbors graph based on the euclidean
  # distance in PCA space and refines the edge weights using the shared overlap
  # of local neighborhoods (Jaccard similarity). FindClusters then groups cells
  # by modularity optimization (Louvain by default).
  clustered <- FindNeighbors(embedded, assay = assay, reduction = "pca", dims = dims)
  clustered <- FindClusters(clustered, resolution = resolution, verbose = F)

  AddMetaData(
    clustered,
    metadata = Idents(clustered),
    col.name = str_c(assay, "_clusters")
  )
}

# Calculate feature fold change with median control group signal
# The signal for feat is fetched from data_slot, added to meta.data and divided
# by the median signal of the cells whose grp_column value is in control_grps.
# If any fold change is exactly 0, half of the smallest non-zero fold change
# is added to every cell as a pseudo count.
#   feat          feature to fetch with FetchData()
#   add_pseudo    currently not used, the pseudo count is added whenever a
#                 fold change of 0 is present
#   fc_column     name of the meta.data column that stores the fold change
calc_feat_fc <- function(sobj_in, feat = "adt_ovalbumin", data_slot = "counts", add_pseudo = FALSE,
                         fc_column = "ova_fc", grp_column = "cell_type", control_grps = c("B cell", "T cell")) {

  res <- sobj_in %>%
    AddMetaData(FetchData(
      object = .,
      vars   = feat,
      slot   = data_slot
    ))

  # The control median is computed within the control group and then spread
  # to all cells with max(), the only non-NA value
  res@meta.data <- res@meta.data %>%
    rownames_to_column("cell_id") %>%
    mutate(con_grp = !!sym(grp_column) %in% control_grps) %>%
    group_by(con_grp) %>%
    mutate(con_med = ifelse(con_grp, median(!!sym(feat)), NA)) %>%
    ungroup() %>%
    mutate(
      con_med = max(con_med, na.rm = TRUE),

      # Use the requested feature instead of a hard-coded column name
      !!sym(fc_column) := !!sym(feat) / con_med
    ) %>%
    dplyr::select(-con_grp, -con_med) %>%
    column_to_rownames("cell_id")

  # Add pseudo count
  if (0 %in% pull(res@meta.data, fc_column)) {
    res@meta.data <- res@meta.data %>%
      rownames_to_column("cell_id") %>%
      mutate(
        pseudo = ifelse(!!sym(fc_column) > 0, !!sym(fc_column), NA),
        pseudo = min(pseudo, na.rm = TRUE) * 0.5,
        !!sym(fc_column) := !!sym(fc_column) + pseudo
      ) %>%
      column_to_rownames("cell_id")
  }

  res
}

# Subset Seurat objects for plotting
# Keeps the cells whose type_column value is in cell_types, then re-scores the
# cell cycle, finds variable features, scales the data and re-runs PCA and
# UMAP. The name argument and ... are accepted but not used in the body.
subset_sobj <- function(sobj_in, name, cell_types, type_column = "cell_type", dims = 1:40,
                        regress_vars = NULL, ...) {

  # Filter cells based on input cell type
  type_cells <- subset(sobj_in, subset = !!sym(type_column) %in% cell_types)

  # Score cell cycle genes
  scored <- CellCycleScoring(
    type_cells,
    s.features   = str_to_title(cc.genes$s.genes),
    g2m.features = str_to_title(cc.genes$g2m.genes)
  )

  # Re-process object
  var_feats <- FindVariableFeatures(scored, selection.method = "vst", nfeatures = 2000)
  scaled    <- ScaleData(var_feats, vars.to.regress = regress_vars)
  reduced   <- RunPCA(scaled)

  RunUMAP(reduced, dims = dims)
}

# Re-cluster and run clustify
# Clusters the cells at the given resolution, stores the clusters in the
# "type_clusters" column and classifies them against ref with clustify().
# When type_in is given, cells whose type_column equals type_in receive the
# clustify assignment as their subtype, all other cells keep their
# type_column value. clust_column combines cluster and subtype for the cells
# matching type_in.
#   prefix  passed to clustify() as rename_prefix, also used to locate the
#           assignment column ("type" or "<prefix>_type")
#   ...     passed to clustify()
run_clustifyr <- function(sobj_in, type_in = NULL, ref, assay = "RNA", resolution = 1.8, dims = 1:40, type_column = "cell_type",
                          subtype_column = "subtype", clust_column = "subtype_cluster", prefix = NULL, ...) {

  clustered <- FindNeighbors(sobj_in, assay = assay, reduction = "pca", dims = dims)
  clustered <- FindClusters(clustered, resolution = resolution, verbose = F)

  clustered <- AddMetaData(
    clustered,
    metadata = Idents(clustered),
    col.name = "type_clusters"
  )

  res <- clustify(
    clustered,
    ref_mat       = ref,
    cluster_col   = "type_clusters",
    rename_prefix = prefix,
    ...
  )

  # Nothing else to do when no cell type was requested
  if (is.null(type_in)) {
    return(res)
  }

  # Set cell subtypes for cells that match type_in
  assigned <- if (is.null(prefix)) "type" else str_c(prefix, "_", "type")

  type_sym   <- sym(type_column)
  sub_sym    <- sym(subtype_column)
  clust_sym  <- sym(clust_column)
  assign_sym <- sym(assigned)

  meta <- rownames_to_column(res@meta.data, "cell_id")

  meta <- mutate(
    meta,
    !!sub_sym   := if_else(!!type_sym == type_in, !!assign_sym, !!type_sym),
    !!sub_sym   := str_to_title_v2(!!sub_sym),
    !!clust_sym := if_else(!!type_sym == type_in, str_c(type_clusters, "-", !!sub_sym), !!sub_sym)
  )

  res@meta.data <- column_to_rownames(meta, "cell_id")

  res
}

# Merge a list of Seurat objects
# Merges the objects in sobj_list (cell ids are prefixed with the list names),
# scales the RNA and antibody assays and finds variable features for RNA.
# A list with a single object is processed without merging.
#   sobj_list     named list of Seurat objects
#   sample_order  levels of orig.ident to move to the front, in order
merge_sobj <- function(sobj_list, sample_order = NULL) {

  if (length(sobj_list) == 0) {
    stop("sobj_list must contain at least one Seurat object")
  }

  # 2:length(sobj_list) would count backwards for a single object
  res <- sobj_list[[1]]

  if (length(sobj_list) > 1) {
    res <- merge(
      x = sobj_list[[1]],
      y = sobj_list[-1],
      add.cell.ids = names(sobj_list)
    )
  }

  res <- res %>%
    ScaleData(assay = "RNA")

  # The antibody assay is created as "ADT" elsewhere in this file, "adt" is
  # still accepted for objects that use the lower-case name
  adt_assays <- intersect(c("ADT", "adt"), Assays(res))

  for (adt_assay in adt_assays) {
    res <- res %>%
      ScaleData(assay = adt_assay)
  }

  res <- res %>%
    FindVariableFeatures(assay = "RNA")

  # Set sample order
  res@meta.data <- res@meta.data %>%
    rownames_to_column("cell_ids") %>%
    mutate(orig.ident = fct_relevel(orig.ident, sample_order)) %>%
    column_to_rownames("cell_ids")

  res
}

# Fit gaussian mixture model for given signal
# Fetches data_column from data_slot, adds 1 to every value and fits a mixture
# model with normalmixEM(). The component with the larger mean is labelled
# "High", the other "Low". Cells are assigned to "High" when their posterior
# for that component is at least prob.
# Returns a list with the per-cell data.frame (res) and the named mu, sigma
# and lambda estimates. Note that the function calls set.seed(42).
#   quiet  suppress the output of normalmixEM()
fit_GMM <- function(sobj_in, data_column = "adt_ovalbumin", data_slot = "counts", prob = 0.5, quiet = FALSE) {

  set.seed(42)

  # Fetch signal and add pseudo count
  data_df <- sobj_in %>%
    FetchData(data_column, slot = data_slot) %>%
    rownames_to_column("cell_id") %>%
    mutate(!!sym(data_column) := !!sym(data_column) + 1) %>%
    column_to_rownames("cell_id")

  # Fit GMM
  signal <- pull(data_df, data_column)

  if (quiet) {
    mixmdl <- quietly(normalmixEM)(signal)$result

  } else {
    mixmdl <- normalmixEM(signal)
  }

  # New column names
  ova_names  <- c("Low", "High")
  comp_names <- c("comp.1", "comp.2")

  if (mixmdl$mu[1] > mixmdl$mu[2]) {
    ova_names <- rev(ova_names)
  }

  names(comp_names)    <- ova_names
  names(mixmdl$mu)     <- ova_names
  names(mixmdl$sigma)  <- ova_names
  names(mixmdl$lambda) <- ova_names

  # Divide into groups
  # rename() is namespace-qualified throughout so that other attached
  # packages cannot mask it
  res <- data.frame(
    cell_id = rownames(data_df),
    data    = data_df[, data_column],
    mixmdl$posterior
  ) %>%
    dplyr::rename(!!sym(data_column) := data) %>%
    dplyr::rename(all_of(comp_names)) %>%
    mutate(GMM_grp = if_else(High >= prob, "High", "Low")) %>%
    column_to_rownames("cell_id")

  res <- list(
    res    = res,
    mu     = mixmdl$mu,
    sigma  = mixmdl$sigma,
    lambda = mixmdl$lambda
  )

  res
}

# Classify cells based on OVA signal
# Optionally restricts the object to cells where filt_column equals filt,
# fits a mixture model with fit_GMM() and labels cells "ova low" or
# "ova high". With return_sobj = FALSE the per-cell data.frame is returned,
# otherwise the groups are added to the meta.data of sobj_in and cells that
# were not part of the fit are labelled "Other".
#   ...  passed to fit_GMM()
classify_ova <- function(sobj_in, filt_column = "cell_type", filt = NULL, data_column = "adt_ovalbumin", 
                         data_slot = "counts", return_sobj = TRUE, ...) {

  # Filter Seurat object
  sobj_filt <- sobj_in

  if (!is.null(filt)) {
    sobj_filt <- sobj_filt %>%
      subset(!!sym(filt_column) == filt)
  }

  # Fit GMM
  gmm_res <- sobj_filt %>%
    fit_GMM(
      data_column = data_column,
      data_slot   = data_slot,
      ...
    )

  # mu is looked up by group name, the names are dropped from the column
  # select() is namespace-qualified as elsewhere in this file
  gmm_df <- gmm_res$res %>%
    rownames_to_column("cell_id") %>%
    mutate(
      mu      = unname(gmm_res$mu[GMM_grp]),
      GMM_grp = str_to_lower(GMM_grp),
      GMM_grp = str_c("ova ", GMM_grp)
    ) %>%
    dplyr::select(-all_of(data_column))

  if (!return_sobj) {
    if (!is.null(filt)) {
      gmm_df <- gmm_df %>%
        mutate(!!sym(filt_column) := filt)
    }

    return(gmm_df)
  }

  # Add OVA groups to meta.data
  gmm_df <- gmm_df %>%
    column_to_rownames("cell_id")

  res <- sobj_in %>%
    AddMetaData(gmm_df)

  res@meta.data <- res@meta.data %>%
    rownames_to_column("cell_id") %>%
    mutate(GMM_grp = if_else(is.na(GMM_grp), "Other", GMM_grp)) %>%
    column_to_rownames("cell_id")

  res
}


# Plotting ----

# Overlay feature data on UMAP or tSNE
# Cannot change number of columns when using FeaturePlot with split.by
#   sobj_in       Seurat object or data.frame holding the plotted columns
#   x, y          columns for the axes; a named value renames the column
#   feature       column used for color; a named value renames the column
#   data_slot     slot used by FetchData() for Seurat input
#   split_id      one column (facet_wrap) or two columns (facet_grid)
#   pt_size       point size
#   pt_outline    size of the black points drawn underneath as outline
#   plot_cols     colors for the feature
#   feat_levels, split_levels
#                 level order for the feature and the facets
#   min_pct, max_pct
#                 percent rank limits; values outside are clipped. Either one
#                 can be given on its own
#   calc_cor      label each panel with the correlation of x and y
#   lm_line       add a linear regression line
#   lab_size      size of the correlation label
#   lab_pos       relative x and y position of the correlation label
#   ...           passed to facet_wrap() or facet_grid()
plot_features <- function(sobj_in, x = "UMAP_1", y = "UMAP_2", feature, data_slot = "data", 
                          split_id = NULL, pt_size = 0.25, pt_outline = NULL, plot_cols = NULL,
                          feat_levels = NULL, split_levels = NULL, min_pct = NULL, max_pct = NULL, 
                          calc_cor = FALSE, lm_line = FALSE, lab_size = 3.7, lab_pos = c(0.8, 0.9), ...) {

  # Format input data
  counts <- sobj_in

  if ("Seurat" %in% class(sobj_in)) {
    vars <- unique(c(x, y, feature, split_id))

    counts <- sobj_in %>%
      FetchData(vars = vars, slot = data_slot) %>%
      as_tibble(rownames = "cell_ids")
  }

  # Rename features
  if (!is.null(names(feature))) {
    counts <- counts %>%
      rename(!!!syms(feature))

    feature <- names(feature)
  }

  if (!is.null(names(x))) {
    counts <- counts %>%
      rename(!!!syms(x))

    x <- names(x)
  }

  if (!is.null(names(y))) {
    counts <- counts %>%
      rename(!!!syms(y))

    y <- names(y)
  }

  # Set min and max values for feature
  # Both limits are derived from the unclipped values before either is applied
  if (!is.null(min_pct) || !is.null(max_pct)) {
    counts <- counts %>%
      mutate(pct_rank = percent_rank(!!sym(feature)))

    if (!is.null(max_pct)) {
      counts <- counts %>%
        mutate(
          max_val = ifelse(pct_rank > max_pct, !!sym(feature), NA),
          max_val = min(max_val, na.rm = TRUE)
        )
    }

    if (!is.null(min_pct)) {
      counts <- counts %>%
        mutate(
          min_val = ifelse(pct_rank < min_pct, !!sym(feature), NA),
          min_val = max(min_val, na.rm = TRUE)
        )
    }

    if (!is.null(max_pct)) {
      counts <- counts %>%
        mutate(!!sym(feature) := if_else(!!sym(feature) > max_val, max_val, !!sym(feature)))
    }

    if (!is.null(min_pct)) {
      counts <- counts %>%
        mutate(!!sym(feature) := if_else(!!sym(feature) < min_val, min_val, !!sym(feature)))
    }
  }

  # Set feature order
  if (!is.null(feat_levels)) {
    counts <- counts %>%
      mutate(!!sym(feature) := fct_relevel(!!sym(feature), feat_levels))
  }

  # Set facet order
  if (!is.null(split_id) && length(split_id) == 1) {
    counts <- counts %>%
      mutate(split_id = !!sym(split_id))

    if (!is.null(split_levels)) {
      counts <- counts %>%
        mutate(split_id = fct_relevel(split_id, split_levels))
    }
  }

  # Calculate correlation
  # Grouping uses the split columns themselves so that two split variables
  # work as well
  if (calc_cor) {
    if (!is.null(split_id)) {
      counts <- counts %>%
        group_by(!!!syms(split_id))
    }

    counts <- counts %>%
      mutate(
        cor_lab = cor(!!sym(x), !!sym(y)),
        cor_lab = round(cor_lab, digits = 2),
        cor_lab = str_c("r = ", cor_lab),
        min_x   = min(!!sym(x)),
        max_x   = max(!!sym(x)),
        min_y   = min(!!sym(y)),
        max_y   = max(!!sym(y)),
        lab_x   = (max_x - min_x) * lab_pos[1] + min_x,
        lab_y   = (max_y - min_y) * lab_pos[2] + min_y
      ) %>%
      ungroup()
  }

  # Create scatter plot
  # To add outline for each cluster create separate layers
  res <- counts %>%
    arrange(!!sym(feature))

  if (!is.null(pt_outline)) {

    if (!is.numeric(counts[[feature]])) {
      res <- res %>%
        ggplot(aes(!!sym(x), !!sym(y), color = !!sym(feature), fill = !!sym(feature)))

      feats <- counts[[feature]] %>%
        unique()

      if (!is.null(feat_levels)) {
        feats <- feat_levels[feat_levels %in% feats]
      }

      for (feat in feats) {
        f_counts <- counts %>%
          filter(!!sym(feature) == feat)

        res <- res +
          geom_point(data = f_counts, aes(fill = !!sym(feature)), size = pt_outline, color = "black", show.legend = FALSE) +
          geom_point(data = f_counts, size = pt_size)
      }

    } else {
      res <- res %>%
        ggplot(aes(!!sym(x), !!sym(y), color = !!sym(feature))) +
        geom_point(aes(fill = !!sym(feature)), size = pt_outline, color = "black", show.legend = FALSE) +
        geom_point(size = pt_size)
    }

  } else {
    res <- res %>%
      ggplot(aes(!!sym(x), !!sym(y), color = !!sym(feature))) +
      geom_point(size = pt_size)
  }

  # Add regression line
  if (lm_line) {
    res <- res +
      geom_smooth(method = "lm", se = FALSE, color = "black", size = 0.5, linetype = 2)
  }

  # Add correlation coefficient label
  # The label is repeated for every row, check_overlap hides the duplicates
  if (calc_cor) {
    res <- res +
      geom_text(
        aes(x = lab_x, lab_y, label = cor_lab),
        color = "black",
        size  = lab_size,
        check_overlap = TRUE,
        show.legend = FALSE
      )
  }

  # Set feature colors
  if (!is.null(plot_cols)) {
    if (is.numeric(counts[[feature]])) {
      res <- res +
        scale_color_gradientn(colors = plot_cols)

    } else {
      res <- res +
        scale_color_manual(values = plot_cols) +
        scale_fill_manual(values = plot_cols)
    }
  }

  # Split plot into facets
  if (!is.null(split_id)) {
    if (length(split_id) == 1) {
      res <- res +
        facet_wrap(~ split_id, ...)

    } else if (length(split_id) == 2) {
      eq <- str_c(split_id[1], " ~ ", split_id[2])

      res <- res +
        facet_grid(as.formula(eq), ...)
    }
  }

  res
}

# Create GO bubble plot
# Plots -log10(p_value) for every term in GO_df, with point size showing the
# intersection size, one facet per database. The n_terms terms with the
# lowest p-value in each database are labelled. Returns an empty plot for
# empty input.
#   GO_df        data.frame with term_id, term_name, source, p_value and
#                intersection_size columns
#   plot_colors  a single color for all points, or one color per database.
#                An unnamed vector of four colors is matched to the databases
#                in the order BP, CC, MF, KEGG
#   n_terms      number of terms to label per database
create_bubbles <- function(GO_df, plot_colors = theme_cols[c(1:2, 4, 9)],
                           n_terms = 15) {

  # Check for empty inputs
  if (is_empty(GO_df) || nrow(GO_df) == 0) {
    res <- ggplot() +
      geom_blank()

    return(res)
  }

  # Database labels in plotting order
  src_labs <- c(
    "GO:BP" = "Biological\nProcess",
    "GO:CC" = "Cellular\nComponent",
    "GO:MF" = "Molecular\nFunction",
    "KEGG"  = "KEGG"
  )

  # Only databases present in the results are recoded and reordered, absent
  # ones would trigger "Unknown levels" warnings
  present    <- src_labs[names(src_labs) %in% GO_df$source]
  recode_key <- set_names(names(present), unname(present))

  # Shorten GO terms and database names
  GO_data <- GO_df %>%
    mutate(
      term_id = str_remove(term_id, "(GO|KEGG):"),
      term_id = str_c(term_id, " ", term_name),
      term_id = str_to_lower(term_id),
      term_id = str_trunc(term_id, 40, "right"),
      source  = fct_recode(as.character(source), !!!recode_key),
      source  = fct_relevel(source, unname(present))
    )

  # Extract top terms for each database
  top_GO <- GO_data %>%
    group_by(source) %>%
    arrange(p_value) %>%
    dplyr::slice(seq_len(n_terms)) %>%
    ungroup()

  # Point colors
  # A vector of several colors cannot be passed as a constant aesthetic
  # unless its length equals the number of rows, so it is mapped to the
  # database instead
  if (length(plot_colors) == 1) {
    pt_layers <- list(
      geom_point(color = plot_colors, alpha = 0.5, show.legend = TRUE)
    )

  } else {
    if (is.null(names(plot_colors)) && length(plot_colors) == length(src_labs)) {
      names(plot_colors) <- unname(src_labs)
    }

    pt_layers <- list(
      geom_point(aes(color = source), alpha = 0.5, show.legend = TRUE),
      scale_color_manual(values = plot_colors, guide = "none")
    )
  }

  # Create bubble plots
  res <- GO_data %>%
    ggplot(aes(1.25, -log10(p_value), size = intersection_size)) +
    pt_layers +
    geom_text_repel(
      aes(2, -log10(p_value), label = term_id),
      data         = top_GO,
      size         = 2.3,
      direction    = "y",
      hjust        = 0,
      segment.size = NA
    ) +
    xlim(1, 8) +
    labs(y = "-log10(p-value)") +
    theme_info +
    theme(
      axis.title.x = element_blank(),
      axis.text.x  = element_blank(),
      axis.ticks.x = element_blank()
    ) +
    facet_wrap(~ source, scales = "free", nrow = 1)

  res
}

# Plot percentage of cells in given groups
# Creates a bar plot with one bar per group_id value, filled by fill_id.
#   sobj_in      Seurat object (meta.data is used) or data.frame
#   split_id     column used for facets
#   group_order  levels of group_id to move to the front
#   plot_cols    fill colors
#   bar_pos      position passed to geom_bar()
#   order_count  order the fill levels by their number of distinct cell_id
#   bar_line     width of the bar outline
#   ...          passed to facet_wrap()
plot_cell_count <- function(sobj_in, group_id, split_id = NULL, group_order = NULL, fill_id, 
                            plot_cols = NULL, x_lab = "Cell type", y_lab = "Fraction of cells",
                            bar_pos = "fill", order_count = T, bar_line = 0, ...) {

  has_split <- !is.null(split_id)

  plot_df <- sobj_in

  if ("Seurat" %in% class(sobj_in)) {
    plot_df <- rownames_to_column(sobj_in@meta.data, "cell_id")
  }

  # Copy the plotted columns to fixed names
  plot_df <- mutate(
    plot_df,
    group_id := !!sym(group_id),
    fill_id  := !!sym(fill_id)
  )

  if (!is.null(group_order)) {
    plot_df <- mutate(plot_df, group_id = fct_relevel(group_id, group_order))
  }

  if (has_split) {
    plot_df <- mutate(plot_df, split_id := !!sym(split_id))
  }

  if (order_count) {
    plot_df <- mutate(plot_df, fill_id = fct_reorder(fill_id, cell_id, n_distinct))
  }

  # Create bar plot
  bars <- ggplot(plot_df, aes(group_id, fill = fill_id)) +
    geom_bar(position = bar_pos, size = bar_line, color = "black") +
    labs(x = x_lab, y = y_lab)

  if (!is.null(plot_cols)) {
    bars <- bars +
      scale_fill_manual(values = plot_cols)
  }

  if (has_split) {
    bars <- bars +
      facet_wrap(~ split_id, ...)
  }

  bars
}


# Differential expression ----

# Find markers with presto
# Runs wilcoxauc() for the groups in group_column and keeps the features that
# pass the adjusted p-value, fold change, AUC and percent expressed filters,
# sorted by decreasing logFC.
#   exclude_clust  group(s) removed before testing, only applied when
#                  groups_use is NULL
#   groups_use     passed to wilcoxauc()
#   uniq           keep only features that are a marker for a single group
#   ...            passed to wilcoxauc()
# The filter defaults are read from the global params list.
find_markers <- function(sobj_in, group_column = NULL, exclude_clust = NULL, groups_use = NULL, 
                         uniq = params$uniq_markers, FC_min = params$FC_min, auc_min = params$auc_min, 
                         p_max = params$p_max_markers, pct_in_min = params$pct_in_min, 
                         pct_out_max = params$pct_out_max, ...) {

  # Remove excluded groups
  # %in% supports several excluded groups, != would recycle exclude_clust
  # along the cells
  if (!is.null(exclude_clust) && is.null(groups_use)) {
    grp_data   <- FetchData(sobj_in, group_column)
    keep_cells <- rownames(grp_data)[!grp_data[[1]] %in% exclude_clust]

    sobj_in <- sobj_in %>%
      subset(cells = keep_cells)
  }

  res <- sobj_in %>%
    wilcoxauc(
      group_by   = group_column,
      groups_use = groups_use,
      ...
    ) %>%
    as_tibble() %>%
    filter(
      padj    < p_max,
      logFC   > FC_min,
      auc     > auc_min,
      pct_in  > pct_in_min,
      pct_out < pct_out_max
    ) %>%
    arrange(desc(logFC))

  if (uniq) {
    res <- res %>%
      group_by(feature) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  res
}

# Find cluster markers for each separate cell type
# Restricts input_sobj to cells where grp_column equals input_grp and runs
# find_markers() on clust_column. Returns NULL when the subset contains fewer
# than two clusters; otherwise the markers with a cell_type column set to
# input_grp.
find_group_markers <- function(input_grp, input_sobj, grp_column, clust_column) {

  grp_sobj <- subset(input_sobj, !!sym(grp_column) == input_grp)

  # At least two clusters are needed for a comparison
  n_clusts <- n_distinct(pull(grp_sobj@meta.data, clust_column))

  if (n_clusts < 2) {
    return(NULL)
  }

  markers <- find_markers(grp_sobj, group_column = clust_column)

  mutate(markers, cell_type = input_grp)
}

# Run gprofiler
# Runs gost() for gene_list and returns the results as a tibble filtered by
# term size and intersection size, sorted by source and p-value. An empty
# gene list returns an empty tibble.
#   genome   organism passed to gost()
#   gmt_id   when given it is used as the organism and sources is set to NULL
#   dbases   sources passed to gost()
#   ...      passed to gost()
# The remaining defaults are read from the global params list.
run_gprofiler <- function(gene_list, genome = NULL, gmt_id = NULL, p_max = params$p_max_GO, 
                          GO_size = params$term_size, intrsct_size = params$intrsct_size, 
                          order_query = params$order_query, dbases = c("GO:BP", "GO:MF", "KEGG"), ...) {

  # Check for empty gene list
  if (is_empty(gene_list)) {
    return(as_tibble(NULL))
  }

  # Check arguments
  if (is.null(genome) && is.null(gmt_id)) {
    stop("ERROR: Must specifiy genome or gmt_id")
  }

  # Use gmt id
  organism <- genome
  sources  <- dbases

  if (!is.null(gmt_id)) {
    organism <- gmt_id
    sources  <- NULL
  }

  # Run gProfileR
  gost_res <- gost(
    gene_list,
    organism       = organism,
    sources        = sources,
    domain_scope   = "annotated",
    significant    = T,
    user_threshold = p_max,
    ordered_query  = order_query,
    ...
  )

  # Format and sort output data.frame
  res <- as_tibble(gost_res$result)

  if (is_empty(res)) {
    return(res)
  }

  res <- filter(res, term_size > GO_size, intersection_size > intrsct_size)

  arrange(res, source, p_value)
}


# Figure panels ----

# Create UMAPs testing different clustering resolutions with clustifyr
# For every resolution in reslns the cells are clustered and classified with
# run_clustifyr(). The figure stacks three rows of UMAPs, one panel per
# resolution: clusters, correlation (r) and assigned cell type.
#   ref_in         reference passed to run_clustifyr()
#   x, y           meta.data columns holding the UMAP coordinates
#   plot_cols      colors for the cell type UMAPs
#   threshold      passed on to clustify() through run_clustifyr()
#   bar_cols       colors for the correlation gradient
#   panel_heights  relative heights of the three rows
#   ...            theme() arguments applied to every row
# Uses the globals outline_guide, blank_theme and get_cols().
create_clust_umaps <- function(sobj_in, ref_in, x = "UMAP_1", y = "UMAP_2", reslns, plot_cols, pt_size = 0.05, pt_outline = 0.06, 
                               threshold = "auto", bar_cols = c(rep("white", 2), "#403164"), panel_heights = c(0.75, 0.75, 1), ...) {

  # Remove UMAP columns from meta.data since clustifyr automatically adds these
  # any_of() also works when only one of the two columns is present
  if (any(c("UMAP_1", "UMAP_2") %in% colnames(sobj_in@meta.data))) {
    sobj_in@meta.data <- sobj_in@meta.data %>%
      dplyr::select(-any_of(c("UMAP_1", "UMAP_2")))
  }

  # Cluster and create data.frame for UMAPs
  # Panels are labelled with the number of clusters and the resolution
  cluster_df <- map(reslns, function(resln) {
    clust_sobj <- sobj_in %>%
      run_clustifyr(
        ref        = ref_in,
        resolution = resln,
        threshold  = threshold
      )

    clust_sobj@meta.data %>%
      as_tibble(rownames = "cell_id") %>%
      mutate(
        n_clust   = n_distinct(seurat_clusters),
        clust_lab = str_c(n_clust, " (", resln, ")"),
        type      = str_to_title_v2(type)
      ) %>%
      arrange(n_clust) %>%
      mutate(clust_lab = fct_inorder(clust_lab)) %>%
      dplyr::select(
        cell_id, orig.ident, seurat_clusters,
        type, n_clust, clust_lab, r, all_of(c(x, y))
      )
  }) %>%
    bind_rows()

  # Helper function to create UMAPs
  # Additional arguments are passed to theme()
  create_umaps <- function(df_in, x, y, feat, plot_cols = get_cols(70), pt_size = 0.05,
                           pt_outline = 0.06, guide = outline_guide, ...) {
    res <- df_in %>%
      plot_features(
        x          = x,
        y          = y,
        feature    = feat,
        pt_size    = pt_size,
        pt_outline = pt_outline,
        plot_cols  = plot_cols,
        split_id   = "clust_lab",
        nrow       = 1
      ) +
      guides(color = guide) +
      blank_theme +
      theme(
        plot.margin     = unit(c(0.7, 0.2, 0.2, 0.2), "cm"),
        legend.position = "bottom",
        legend.title    = element_blank(),
        legend.text     = element_text(size = 10)
      ) +
      theme(...)

    res
  }

  # Cluster UMAPs
  umap_guide <- outline_guide
  umap_guide$nrow <- 4

  clust_gg <- cluster_df %>%
    create_umaps(
      x               = x,
      y               = y,
      pt_size         = pt_size,
      pt_outline      = pt_outline,
      feat            = "seurat_clusters",
      guide           = umap_guide,
      legend.position = "none",
      ...
    )

  # Correlation UMAPs
  bar_guide <- guide_colorbar(frame.colour = "black", frame.linewidth = 0.2, barwidth = unit(0.2, "cm"))

  cor_gg <- cluster_df %>%
    create_umaps(
      x               = x,
      y               = y,
      pt_size         = pt_size,
      pt_outline      = pt_outline,
      feat            = "r",
      guide           = bar_guide,
      plot_cols       = bar_cols,
      legend.position = "right",
      legend.title    = element_text(),
      ...
    )

  # Cell type UMAPs
  type_gg <- cluster_df %>%
    create_umaps(
      x          = x,
      y          = y,
      pt_size    = pt_size,
      pt_outline = pt_outline,
      feat       = "type",
      guide      = umap_guide,
      plot_cols  = plot_cols,
      ...
    )

  # Create final figure
  plot_grid(
    clust_gg, cor_gg, type_gg,
    rel_heights = panel_heights,
    ncol  = 1,
    align = "v",
    axis  = "rl"
  )
}

# Create reference UMAP for comparisons
# Wraps plot_features() with a point size of 0.1 * pt_mtplyr, a legend at the
# top and the blank theme. When pt_outline is not given the outline is 0.3
# larger than the point size.
#   color_guide  guide used for the color legend
#   ...          passed to plot_features()
create_ref_umap <- function(input_sobj, pt_mtplyr = 1, pt_outline = NULL, color_guide, ...) {

  pt_size <- 0.1 * pt_mtplyr

  outline_size <- if (is.null(pt_outline)) pt_size + 0.3 else pt_outline

  umap <- plot_features(
    input_sobj,
    pt_size    = pt_size,
    pt_outline = outline_size,
    ...
  )

  legend_theme <- theme(
    legend.position = "top",
    legend.title    = element_blank(),
    legend.text     = element_text(size = 10)
  )

  umap +
    guides(color = color_guide) +
    blank_theme +
    legend_theme
}

# Create UMAPs showing marker gene signal
# Returns a list with one UMAP per marker, colored with a gradient from
# low_col to the marker's color.
#   input_markers  colors named by marker; when umap_col is given, a vector of
#                  marker names that all receive umap_col
#   add_outline    passed to plot_features() as pt_outline
#   pt_mtplyr      multiplier for the point size (0.1 * pt_mtplyr)
create_marker_umaps <- function(input_sobj, input_markers, umap_col = NULL, add_outline = NULL, 
                                pt_mtplyr = 1, low_col = "#fafafa") {

  pt_size <- 0.1 * pt_mtplyr

  # Use the same color for every marker
  marker_cols <- input_markers

  if (!is.null(umap_col)) {
    marker_cols <- set_names(
      x  = rep(umap_col, length(input_markers)),
      nm = input_markers
    )
  }

  marker_theme <- theme(
    plot.title        = element_text(size = 13),
    legend.position   = "bottom",
    legend.title      = element_blank(),
    legend.text       = element_text(size = 8),
    legend.key.height = unit(0.1, "cm"),
    legend.key.width  = unit(0.3, "cm"),
    axis.title.y      = element_text(size = 13, color = "white"),
    axis.text.y       = element_text(size = 8, color = "white")
  )

  # imap() passes the color as first and the marker name as second argument
  plot_marker <- function(high_col, marker) {
    umap <- plot_features(
      input_sobj,
      feature    = marker,
      plot_cols  = c(low_col, high_col),
      pt_outline = add_outline,
      pt_size    = pt_size
    )

    umap +
      ggtitle(marker) +
      blank_theme +
      marker_theme
  }

  imap(marker_cols, plot_marker)
}

# Create boxplots showing marker gene signal
#
# Fetches the cluster column and the first `n_boxes` markers from `input_sobj`
# with FetchData(), reshapes the data to long format and draws one facet per
# marker with clusters on the x-axis.
#
# Arguments:
#   input_sobj     Object passed to FetchData() (presumably a Seurat object).
#   input_markers  Character vector of features; only the first `n_boxes` are
#                  plotted. Facet labels are truncated to 9 characters.
#   clust_column   Name of the meta.data column holding cluster labels.
#   box_cols       Named vector of colors; the names set the cluster order.
#   group          Optional group label. Groups are derived by removing
#                  `clust_regex` from the cluster label; only cells in `group`
#                  are kept.
#   include_legend Show a legend on top (uses the global `col_guide`).
#   all_boxes      Force boxplots (also used whenever there are > 6 clusters).
#   all_violins    Use violins when boxplots are not selected.
#   order_boxes    Reorder clusters by decreasing mean signal.
#   clust_regex    Regex removed from the cluster label to obtain the group.
#   n_boxes        Number of markers to plot; also sets the facet column count.
#   median_pt      Size of the median point.
#   n_rows         Number of facet rows used to compute the column count.
#   pt_mtplyr      Multiplier for point size in the quasirandom plot.
#   exclude_clust  Cluster labels to drop before plotting.
#   ...            Forwarded to theme().
#
# Returns a ggplot, or a cowplot grid when blank padding is added for missing
# facets. Relies on the global theme `theme_info`.
create_marker_boxes <- function(input_sobj, input_markers, clust_column, box_cols, group = NULL, include_legend = FALSE,
                                all_boxes = FALSE, all_violins = FALSE, order_boxes = TRUE, clust_regex = "\\-[a-zA-Z0-9_ ]+$",
                                n_boxes = 10, median_pt = 0.75, n_rows = 2, pt_mtplyr = 1, exclude_clust = NULL, ...) {
  
  # Retrieve and format data for boxplots
  input_markers <- input_markers %>%
    head(n_boxes)
  
  box_data <- input_sobj %>%
    FetchData(c(clust_column, input_markers)) %>%
    as_tibble(rownames = "cell_id") %>%
    filter(!(!!sym(clust_column) %in% exclude_clust)) %>%
    mutate(grp = str_remove(!!sym(clust_column), clust_regex))
  
  # Truncate marker names the same way the facet keys are truncated below so
  # they can be used as factor levels
  input_markers <- input_markers %>%
    str_trunc(9)
  
  # Filter based on input group
  if (!is.null(group)) {
    box_data <- box_data %>%
      filter(grp == group)
  }
  
  # Format data for plots
  box_data <- box_data %>%
    pivot_longer(cols = c(-cell_id, -grp, -!!sym(clust_column)), names_to = "key", values_to = "Counts") %>%
    mutate(
      !!sym(clust_column) := fct_relevel(!!sym(clust_column), names(box_cols)),
      key = str_trunc(key, width = 9, side = "right"),
      key = fct_relevel(key, input_markers)
    )
  
  # Order boxes by mean signal
  if (order_boxes) {
    box_data <- box_data %>%
      mutate(!!sym(clust_column) := fct_reorder(!!sym(clust_column), Counts, mean, .desc = TRUE))
  }
  
  n_clust <- box_data %>%
    pull(clust_column) %>%
    n_distinct()
  
  # Create plots
  # The column count is based on the requested number of boxes, not the number
  # of markers actually present, so panels keep a consistent width
  n_cols <- ceiling(n_boxes / n_rows)
  
  res <- box_data %>%
    ggplot(aes(!!sym(clust_column), Counts, fill = !!sym(clust_column))) + 
    facet_wrap(~ key, ncol = n_cols) +
    scale_color_manual(values = box_cols) +
    theme_info +
    theme(
      panel.spacing.x  = unit(0.7, "cm"),
      strip.background = element_blank(),
      strip.text       = element_text(size = 13),
      legend.position  = "none",
      axis.title.x     = element_blank(),
      axis.title.y     = element_text(size = 13),
      axis.text.x      = element_blank(),
      axis.text.y      = element_text(size = 8),
      axis.ticks.x     = element_blank(),
      axis.line.x      = element_blank()
    )
  
  # Adjust output plot type
  if (n_clust > 6 || all_boxes) {
    res <- res +
      stat_summary(geom = "point", shape = 22, fun = median, size = 0) +
      stat_summary(geom = "point", shape = 22, fun = median, size = median_pt * 2, color = "black") +
      geom_boxplot(
        color          = "white",
        fill           = "white",
        alpha          = 1,
        size           = 0.3,
        outlier.colour = "white",
        outlier.alpha  = 1,
        outlier.size   = 0.1,
        coef           = 0       # To exclude whiskers
      ) +
      geom_boxplot(
        size           = 0.3,
        outlier.colour = "grey85",
        outlier.alpha  = 1,
        outlier.size   = 0.1,
        show.legend    = FALSE,
        coef           = 0,
        fatten = 0
      ) +
      stat_summary(
        aes(color = !!sym(clust_column)),
        geom        = "point",
        shape       = 22,
        fun         = median,
        size        = median_pt,
        stroke      = 0.75,
        fill        = "white",
        show.legend = FALSE
      ) +
      guides(fill = guide_legend(override.aes = list(size = 3.5, stroke = 0.25))) +
      scale_fill_manual(values = box_cols) +
      theme(
        panel.background = element_rect(color = "#fafafa", fill = "#fafafa"),
        panel.spacing.x  = unit(0.2, "cm")
      ) +
      theme(...)
      
  } else if (all_violins) {
    # The color scale was already added to the base plot above, so only the
    # fill scale is added here
    res <- res +
      geom_violin(aes(fill = !!sym(clust_column)), size = 0.2) +
      stat_summary(
        aes(color = !!sym(clust_column)),
        geom   = "point",
        shape  = 22,
        fun    = median,
        size   = median_pt,
        stroke = 0.75,
        fill   = "white"
      ) +
      scale_fill_manual(values = box_cols) +
      theme(
        panel.background = element_rect(color = "#fafafa", fill = "#fafafa"),
        panel.spacing.x  = unit(0.2, "cm")
      ) +
      theme(...)
    
  } else {
    pt_size <- 0.3 * pt_mtplyr
    
    res <- res +
      geom_quasirandom(size = pt_size) +
      theme(...)
  }
  
  # Add legend
  if (include_legend) {
    res <- res +
      guides(color = col_guide) +
      theme(legend.position = "top")
  }
  
  # Add blank space for missing facets
  # When all facets fit on a single row, the plot is placed in a grid with
  # empty cells so panels keep roughly the size they have in a full figure
  n_keys <- n_distinct(box_data$key)
  
  if (n_keys <= n_cols && n_rows > 1) {
    # A single facet is given the width of two
    n_keys <- if (n_keys == 1) 2 else n_keys
    
    # Never request fewer than one grid column
    n_cols <- max(1, floor(n_cols / n_keys))
    
    res <- res %>%
      plot_grid(
        ncol = n_cols,
        nrow = 2
      )
  }
  
  res
}

# Create figure summarizing marker genes
#
# Stacks three panels vertically: marker UMAPs (reference UMAP plus up to 7
# marker UMAPs under a shared legend), marker boxplots and GO term bubbles.
# Panels that cannot be drawn (no markers / no GO terms) are left blank.
# When `xlsx_name` is given, GO terms and markers are appended to
# "<xlsx_name>_GO.xlsx" and "<xlsx_name>_markers.xlsx" under `sheet_name`.
# `GO_genome` is accepted but not referenced in the body. Extra arguments in
# `...` are forwarded to create_marker_boxes().
create_marker_fig <- function(input_sobj, input_markers, input_GO, clust_column, 
                              input_umap, umap_color, fig_heights = c(0.46, 0.3, 0.3), 
                              GO_genome = params$genome, box_colors, n_boxes = 10,
                              umap_outline = NULL, umap_mtplyr = 1, xlsx_name = NULL, 
                              sheet_name = NULL, ...) {
  
  # Start every panel as an empty placeholder
  umap_panel <- box_panel <- go_panel <- ggplot() +
    geom_blank() +
    theme_void()
  
  if (nrow(input_markers) > 0) {
    marker_feats <- head(input_markers$feature, n_boxes)
    
    # Pull the legend off the reference UMAP so it can be drawn separately
    legend_grob <- get_legend(input_umap)
    bare_umap   <- input_umap + theme(legend.position = "none")
    
    # Reference UMAP first, followed by the marker UMAPs
    feat_umaps <- create_marker_umaps(
      input_sobj,
      input_markers = head(marker_feats, 7),
      umap_col      = umap_color,
      add_outline   = umap_outline,
      pt_mtplyr     = umap_mtplyr
    )
    
    umap_grid <- plot_grid(
      plotlist = append(list(bare_umap), feat_umaps),
      ncol     = 4,
      nrow     = 2,
      align    = "vh",
      axis     = "trbl"
    )
    
    umap_panel <- plot_grid(
      legend_grob, umap_grid,
      rel_heights = c(0.2, 0.9),
      nrow = 2
    )
    
    # Boxplots of marker gene signal
    box_panel <- create_marker_boxes(
      input_sobj,
      input_markers = marker_feats,
      clust_column  = clust_column,
      box_cols      = box_colors,
      n_boxes       = n_boxes,
      plot.margin   = unit(c(0.8, 0.2, 0.2, 0.2), "cm"),
      ...
    )
    
    # GO term bubbles and optional GO spreadsheet
    if (nrow(input_GO) > 0) {
      go_theme <- theme(
        plot.margin      = unit(c(0.8, 0.2, 0.2, 0.2), "cm"),
        strip.background = element_blank(),
        strip.text       = element_text(size = 13),
        axis.title.y     = element_text(size = 13),
        axis.text.y      = element_text(size = 8),
        axis.line.x      = element_blank(),
        legend.position  = "bottom",
        legend.title     = element_blank(),
        legend.text      = element_text(size = 8)
      )
      
      go_panel <- create_bubbles(input_GO, plot_colors = umap_color) + go_theme
      
      if (!is.null(xlsx_name)) {
        go_table <- input_GO %>%
          dplyr::select(
            term_name,  term_id,
            source,     effective_domain_size,
            query_size, intersection_size,
            p_value,    significant 
          ) %>%
          arrange(source, p_value)
        
        write.xlsx(
          go_table,
          file      = str_c(xlsx_name, "_GO.xlsx"),
          sheetName = sheet_name,
          append    = T
        )
      }
    }
    
    # Optional marker spreadsheet
    if (!is.null(xlsx_name)) {
      write.xlsx(
        input_markers,
        file      = str_c(xlsx_name, "_markers.xlsx"),
        sheetName = sheet_name,
        append    = T
      )
    }
  }
  
  # Assemble the final figure
  plot_grid(
    umap_panel, box_panel, go_panel,
    rel_heights = fig_heights,
    ncol        = 1
  )
}

# Filter clusters and set cluster order
#
# Keeps the names of `input_cols`, in their original order, whose value
# appears in the `group` column of `input_marks` at least `n_cutoff` times.
set_cluster_order <- function(input_cols, input_marks, n_cutoff = 5) {
  
  # Groups with enough marker rows
  kept_groups <- input_marks %>%
    group_by(group) %>%
    filter(n() >= n_cutoff) %>%
    ungroup() %>%
    pull(group) %>%
    unique()
  
  # Preserve the order given by input_cols
  col_names <- names(input_cols)
  
  col_names[col_names %in% kept_groups]
}

# Create v1 panel for marker genes
#
# Finds marker genes with find_markers(), runs run_gprofiler() on the markers
# of each group (ordered by decreasing logFC), and then, for every cluster
# returned by set_cluster_order(), prints a markdown heading, a figure from
# create_marker_fig() and a one-line summary. Called for its side effects
# (cat()/print()); the loop is the last expression, so nothing useful is
# returned.
#
# Arguments:
#   input_sobj     Object passed to the marker and plotting helpers
#                  (presumably a Seurat object; ncol() is used as cell count).
#   input_cols     Named color vector; names give clusters and their order.
#   input_umap     Optional pre-built reference UMAP. When NULL, one is built
#                  per cluster with the current cluster plotted last.
#   clust_column   Name of the column holding cluster labels.
#   order_boxes    Forwarded to create_marker_fig().
#   color_guide    Legend guide for the reference UMAP.
#   uniq           Keep only GO terms that appear for a single group.
#   umap_mtplyr    Point size multiplier, used only when there are fewer than
#                  500 cells; otherwise 1 is used.
#   xlsx_name      Optional prefix for spreadsheet output.
#   exclude_clust  Clusters to leave out.
#   groups_use     Forwarded to find_markers().
#   ...            Forwarded to create_marker_fig().
#
# Reads `params$genome` (and `params$uniq_GO` for the default of `uniq`).
create_marker_panel_v1 <- function(input_sobj, input_cols, input_umap = NULL, clust_column, order_boxes = TRUE,
                                   color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)),
                                   uniq = params$uniq_GO, umap_mtplyr = 6, xlsx_name = NULL, exclude_clust = NULL,
                                   groups_use = NULL, ...) {
  
  # Set point size
  # ref_mtplyr <- if_else(umap_mtplyr == 1, umap_mtplyr, umap_mtplyr * 2.5)
  # A plain `if` is used because the condition is a scalar; dplyr::if_else()
  # would also reject an integer `umap_mtplyr` paired with the double 1
  if (ncol(input_sobj) >= 500) {
    umap_mtplyr <- 1
  }
  
  ref_mtplyr <- umap_mtplyr
  
  # Find marker genes
  markers <- input_sobj %>%
    find_markers(
      group_column  = clust_column,
      groups_use    = groups_use,
      exclude_clust = exclude_clust
    )
  
  # Find GO terms
  # With no markers the empty marker table stands in for the GO table so the
  # per-cluster filtering below still works
  GO_df <- markers
  
  if (nrow(markers) > 0) {
    GO_df <- markers %>%
      group_by(group) %>%
      do({
        arrange(., desc(logFC)) %>%
          pull(feature) %>%
          run_gprofiler(genome = params$genome)
      }) %>%
      ungroup()
  }
  
  # Keep only GO terms found for exactly one group
  if (uniq && nrow(GO_df) > 0) {
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }
  
  # Set cluster order based on order of input_cols
  fig_clusters <- input_cols %>%
    set_cluster_order(markers)
  
  fig_clusters <- fig_clusters[!fig_clusters %in% exclude_clust]
  
  # Create figures
  for (clust in fig_clusters) {
    cat("\n#### ", clust, "\n", sep = "")
    
    # Filter markers and GO terms
    fig_marks <- markers %>%
      filter(group == clust)
    
    fig_GO <- GO_df %>%
      filter(group == clust)
    
    # Create reference umap
    ref_umap <- input_umap
    umap_col <- input_cols[clust]
    
    if (is.null(input_umap)) {
      # Put the current cluster last in the level order
      umap_levels <- input_cols[names(input_cols) != clust]
      umap_levels <- names(c(umap_levels, umap_col))
      
      ref_umap <- input_sobj %>%
        create_ref_umap(
          feature     = clust_column,
          plot_cols   = input_cols,
          feat_levels = umap_levels,
          pt_mtplyr   = ref_mtplyr,
          color_guide = color_guide
        )
    }
    
    # Create panel
    marker_fig <- input_sobj %>%
      create_marker_fig(
        input_markers = fig_marks,
        input_GO      = fig_GO,
        clust_column  = clust_column,
        input_umap    = ref_umap,
        umap_color    = umap_col,
        box_colors    = input_cols,
        order_boxes   = order_boxes,
        umap_mtplyr   = umap_mtplyr,
        xlsx_name     = xlsx_name,
        sheet_name    = clust,
        exclude_clust = exclude_clust,
        ...
      )
    
    cat(nrow(fig_marks), "marker genes and", nrow(fig_GO), "GO terms were identified.")
    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}

# Create v2 panel that splits plots into groups
#
# Like the v1 panel, but markers are supplied by the caller and each figure is
# restricted to the group the cluster belongs to. The group is the cluster
# name with `clust_regex` removed; clusters sharing that group prefix keep
# their colors and all other cells are labelled "Other" on the reference UMAP.
# Called for its side effects (cat()/print()).
#
# Arguments:
#   input_sobj     Object passed to FetchData() and the plotting helpers.
#   input_markers  Marker table with `group`, `feature` and `logFC` columns.
#   input_cols     Named color vector; names give clusters and their order.
#   grp_column     Name of the column holding group labels.
#   clust_column   Name of the column holding cluster labels.
#   color_guide    Legend guide for the reference UMAP.
#   uniq_GO        Keep only GO terms that appear for a single group.
#   umap_mtplyr    Point size multiplier, used only when there are fewer than
#                  500 cells; otherwise 1 is used.
#   xlsx_name      Optional prefix for spreadsheet output.
#   clust_regex    Regex removed from a cluster name to obtain its group.
#   ...            Forwarded to create_marker_fig().
#
# Reads `params$genome` (and `params$uniq_GO` for the default of `uniq_GO`).
create_marker_panel_v2 <- function(input_sobj, input_markers, input_cols, grp_column, clust_column,
                                   color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)), 
                                   uniq_GO = params$uniq_GO, umap_mtplyr = 6, xlsx_name = NULL, 
                                   clust_regex = "\\-[a-zA-Z0-9_ ]+$", ...) {
  
  # Set point size
  # ref_mtplyr <- if_else(ncol(input_sobj) < 500, umap_mtplyr * 2.5, 1)
  # A plain `if` is used because the condition is a scalar; dplyr::if_else()
  # would also reject an integer `umap_mtplyr` paired with the double 1
  if (ncol(input_sobj) >= 500) {
    umap_mtplyr <- 1
  }
  
  ref_mtplyr <- umap_mtplyr
  
  # Figure colors and order
  fig_clusters <- input_cols %>%
    set_cluster_order(input_markers)
  
  # Find GO terms
  GO_df <- input_markers %>%
    group_by(group) %>%
    do({
      arrange(., desc(logFC)) %>%
        pull(feature) %>%
        run_gprofiler(genome = params$genome)
    }) %>%
    ungroup()
  
  # Keep only GO terms found for exactly one group
  if (uniq_GO && nrow(GO_df) > 0) {
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }
  
  # Create figures
  for (clust in fig_clusters) {
    cat("\n#### ", clust, "\n", sep = "")
    
    # Filter markers and GO terms
    fig_marks <- input_markers %>%
      filter(group == clust)
    
    fig_GO <- GO_df %>%
      filter(group == clust)
    
    # Set colors
    umap_col <- input_cols[clust]
    
    group <- clust %>%
      str_remove(clust_regex)
    
    # Escape every "+" in the group name (not just the first) so the name is
    # matched literally
    grp_regex <- str_c("^", group, "-") %>%
      str_replace_all("\\+", "\\\\+")
    
    # Colors for clusters in this group, with the current cluster moved last
    fig_cols <- input_cols[grepl(grp_regex, names(input_cols))]
    fig_cols <- c( "Other" = "#fafafa", fig_cols)
    ref_cols <- fig_cols[names(fig_cols) != clust]
    ref_cols <- c(ref_cols, umap_col)
    
    # Create reference UMAP
    # Cells outside the current group are relabelled "Other"
    ref_umap <- input_sobj %>%
      FetchData(c("UMAP_1", "UMAP_2", grp_column, clust_column)) %>%
      as_tibble(rownames = "cell_id") %>%
      mutate(!!sym(clust_column) := if_else(
        !!sym(grp_column) != group, 
        "Other", 
        !!sym(clust_column)
      )) %>%
      create_ref_umap(
        feature     = clust_column,
        plot_cols   = ref_cols,
        feat_levels = names(ref_cols),
        pt_mtplyr   = ref_mtplyr,
        color_guide = color_guide
      )
    
    # Create panel
    marker_fig <- input_sobj %>%
      create_marker_fig(
        input_markers = fig_marks,
        input_GO      = fig_GO,
        clust_column  = clust_column,
        input_umap    = ref_umap,
        umap_color    = umap_col,
        box_colors    = fig_cols,
        group         = group,
        umap_mtplyr   = umap_mtplyr,
        xlsx_name     = xlsx_name,
        sheet_name    = clust,
        ...
      )
    
    cat(nrow(fig_marks), "marker genes were identified.", nrow(fig_GO), "GO terms were identified.")
    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}


# Load packages
R_packages <- c(
  "tidyverse",  "Seurat",
  "gprofiler2", "knitr",
  "cowplot",    "ggbeeswarm",
  "ggrepel",    "RColorBrewer",
  "xlsx",       "colorblindr",
  "ggforce",    "broom",
  "mixtools",   "clustifyr",
  "boot",       "scales",
  "patchwork",  "ComplexHeatmap",
  "devtools",   "broom",
  "presto",     "here",
  "clustifyrdata"
)

purrr::walk(R_packages, library, character.only = T)


# Default chunk options
# Hide messages and warnings in knitted chunk output
knitr::opts_chunk$set(message = F, warning = F)

# Set paths/names
# Every entry of params$sobjs is indexed by "mat_in", "sobj_out", "cell_type"
# and "title"; pull_nest_vec() extracts one of these from each entry as a
# character vector

# Input matrix paths, resolved with here()
mat_paths <- params$sobjs %>%
  pull_nest_vec("mat_in") %>%
  map_chr(here)

# Output paths for the processed objects, inside params$rds_dir
so_out <- params$sobjs %>%
  pull_nest_vec("sobj_out") %>%
  map_chr(~ here(params$rds_dir, .x))

# Cell type assigned to each object
so_types <- params$sobjs %>%
  pull_nest_vec("cell_type")

# Title for each object
so_titles <- params$sobjs %>%
  pull_nest_vec("title")

# Paths to the cell type and subtype reference files, inside params$ref_dir
type_ref_path <- here(params$ref_dir, params$type_ref)

subtype_ref_paths <- params$subtype_refs %>%
  map_chr(~ here(params$ref_dir, .x))

# Load clustifyr refs?
# TRUE only when every reference file already exists
load_refs <- all(file.exists(c(type_ref_path, subtype_ref_paths)))

# Load Seurat objects?
# TRUE only when every output object already exists
load_sobjs <- all(file.exists(so_out))

# Clustering parameters
# Resolution and threshold used for cell type assignment
type_res    <- 0.8
type_thresh <- 0.5

# Clustering resolution for subtype assignment, looked up by object name
subtype_res <- c(
  d2_DC   = 1.6,
  d14_DC  = 1.6,
  d2_LEC  = 1,
  d14_LEC = 1,
  d2_FRC  = 1.6,
  d14_FRC = 1.6
)

# Threshold used for subtype assignment
subtype_thresh <- 0.5


# Parameters
# meta.data columns of the reference objects holding cell type and subtype
type_column <- "cell_type1"
subtype_column <- "cell_type2"

# Load Seurat objects for refs
ref_sobjs <- params$ref_sobjs %>%
  here(params$rds_dir, .) %>%
  map(read_rds)

# Merge the reference objects into one. With a single object there is nothing
# to merge; indexing with `2:length(ref_sobjs)` would count down to c(2, 1) in
# that case, so the remaining objects are selected with `[-1]` instead
ref_sobj <- ref_sobjs[[1]]

if (length(ref_sobjs) > 1) {
  ref_sobj <- merge(ref_sobj, ref_sobjs[-1])
}

# Combine CCR7hi XCR1- Mig cDC2s and CCR7hi Mig cDC2s
# Row names are moved into a column and back because the dplyr verbs do not
# keep them
ref_sobj@meta.data <- ref_sobj@meta.data %>%
  rownames_to_column("cell_id") %>%
  mutate(
    !!sym(subtype_column) := str_replace(!!sym(subtype_column), "^CCR7hi XCR1- mig cDC2$", "CCR7hi mig cDC2")
  ) %>%
  column_to_rownames("cell_id")

# Cell type refs
type_ref <- ref_sobj %>%
  seurat_ref(cluster_col = type_column)

# Cell subtype refs
# Create separate subtype ref for each cell type
cell_types <- ref_sobj@meta.data %>%
  pull(type_column) %>%
  unique()

# List of subtype references named by cell type
subtype_refs <- cell_types %>%
  set_names(., .) %>%
  map(~ {
    ref_sobj %>%
      subset(subset = !!sym(type_column) == .x) %>%
      seurat_ref(cluster_col = subtype_column)
  })

# Load Xiang et al. LEC ref
# The object stored in the .rda file is expected to be named after the file
# (file name without the ".rda" extension). The name is kept in its own
# variable so load() cannot overwrite it, and the object is looked up with
# get() rather than by parsing and evaluating the name as code
LEC_ref_name <- str_remove(params$xiang_ref, "\\.rda$")

load(here(params$ref_dir, params$xiang_ref))

LEC_ref <- get(LEC_ref_name)

# Add Immgen endothelial refs
# Keep the "Endothelial" columns of immgen_ref and the genes shared with the
# LEC reference, then put both matrices in the same row order
immgen_LEC <- immgen_ref[, grepl("Endothelial", colnames(immgen_ref))]
immgen_LEC <- immgen_LEC[rownames(immgen_LEC) %in% rownames(LEC_ref), ]
colnames(immgen_LEC) <- colnames(immgen_LEC) %>%
  str_replace("Endothelial cells \\(BEC\\)", "BEC")

LEC_ref <- LEC_ref[rownames(immgen_LEC), ]

# cbind() below pairs rows by position, so the row names must agree exactly
if (!identical(rownames(immgen_LEC), rownames(LEC_ref))) {
  stop("LEC reference rownames do not match.")
}

subtype_refs$LEC <- cbind(LEC_ref, immgen_LEC)

# Save subtype ref matrices
# For each entry of params$subtype_refs, .x is the output file name and .y is
# the entry name used to look up the matrix in subtype_refs. assign() creates
# a local variable named after the file so save() stores the matrix under
# that name
params$subtype_refs %>%
  iwalk(~ {
    ref <- subtype_refs[[.y]]
    name <- str_remove(.x, "\\.rda$")
    
    assign(name, ref)
    
    save(list = name, file = here(params$ref_dir, .x))
  })

save(type_ref, file = here(params$ref_dir, "ref_celltype_walsh.rda"))


# Create Seurat objects ----

# Build normalized, clustered objects from a vector of matrix paths
#
# Each distinct path in `paths_in` is read once with create_sobj(); the
# resulting list is named by path. Every object is then normalized with
# norm_sobj() and clustered with cluster_RNA() at the given `resolution`.
create_sobjs_01 <- function(paths_in, resolution) {
  
  # Avoid loading same matrix twice
  uniq_paths <- unique(paths_in)
  
  loaded <- map(
    set_names(uniq_paths, uniq_paths),
    create_sobj,
    adt_count_min = 0
  )
  
  # Normalize and cluster a single object
  process_sobj <- function(so) {
    so_norm <- norm_sobj(so)
    
    cluster_RNA(
      so_norm,
      assay      = "RNA",
      resolution = resolution,
      pca_meta   = F,
      umap_meta  = F 
    )
  }
  
  map(loaded, process_sobj)
}

# Create one normalized, clustered object per unique input matrix. The result
# is named by matrix path, which is what the later match() against mat_paths
# relies on
sobjs_raw <- unique(mat_paths) %>%
  create_sobjs_01(resolution = type_res)

# Avoid loading same matrix twice
# sobjs_raw <- unique(mat_paths) %>%
#   set_names(., .) %>%
#   map(create_sobj, adt_count_min = 0)

# Normalize and cluster
# sobjs_raw <- sobjs_raw %>%
#   map(~ {
#     .x %>%
#       norm_sobj() %>%
#       cluster_RNA(
#         assay      = "RNA",
#         resolution = type_res,
#         pca_meta   = F,
#         umap_meta  = F 
#       )
#   })


# Annotate cell types ----

# Assign cell types to each object and calculate OVA fold change
#
# Every object in `sobjs_in` is annotated with clustify() against `ref_mat`
# using `threshold`; the resulting `t1_type` column is copied to `cell_type`
# and the UMAP_1/UMAP_2 columns are removed from meta.data. calc_feat_fc() is
# then run on each annotated object for "adt_ovalbumin", using "B cell" and
# "T cell" as control groups.
assign_cell_types_02 <- function(sobjs_in, ref_mat, threshold) {
  
  # Assign cell types for a single object
  annotate_types <- function(so) {
    so_typed <- clustify(
      input         = so,
      cluster_col   = "RNA_clusters",
      ref_mat       = ref_mat,
      rename_prefix = "t1",
      seurat_out    = T,
      threshold     = threshold
    )
    
    new_meta <- so_typed@meta.data %>%
      rownames_to_column("cell_id") %>%
      mutate(cell_type = t1_type) %>%
      select(-UMAP_1, -UMAP_2) %>%
      column_to_rownames("cell_id")
    
    so_typed@meta.data <- new_meta
    
    so_typed
  }
  
  # Calculate OVA fold change for a single object
  add_ova_fc <- function(so) {
    calc_feat_fc(
      sobj_in      = so,
      feat         = "adt_ovalbumin",
      data_slot    = "counts",
      grp_column   = "cell_type",
      control_grps = c("B cell", "T cell"),
      add_pseudo   = T
    )
  }
  
  # Annotate all objects first, then add fold changes
  typed_sobjs <- map(sobjs_in, annotate_types)
  
  map(typed_sobjs, add_ova_fc)
}

# Annotate cell types in every raw object against the cell type reference
# (type_ref) using the threshold type_thresh, and add OVA fold changes
sobjs_raw <- sobjs_raw %>%
  assign_cell_types_02(
    ref_mat = type_ref,
    threshold = type_thresh
  )

# # Assign cell types
# sobjs_raw <- sobjs_raw %>%
#   map(~ {
#     res <- clustify(
#       input         = .x,
#       cluster_col   = "RNA_clusters",
#       ref_mat       = type_ref,
#       rename_prefix = "t1",
#       seurat_out    = T,
#       threshold     = type_thresh
#     )
#     
#     res@meta.data <- res@meta.data %>%
#       rownames_to_column("cell_id") %>%
#       mutate(cell_type = t1_type) %>%
#       select(-UMAP_1, -UMAP_2) %>%
#       column_to_rownames("cell_id")
#     
#     res
#   })
# 
# # Calculate OVA fold change
# sobjs_raw <- sobjs_raw %>%
#   map(~ {
#     calc_feat_fc(
#       sobj_in      = .x,
#       feat         = "adt_ovalbumin",
#       data_slot    = "counts",
#       grp_column   = "cell_type",
#       control_grps = c("B cell", "T cell"),
#       add_pseudo   = T
#     )
#   })


# Expand Seurat object list for subsets ----
# sobjs_raw holds one object per unique matrix path and is named by path.
# match() selects the object for every entry of mat_paths, so a matrix used by
# several entries is repeated; the list is then renamed with names(mat_paths)
sobjs <- sobjs_raw[match(mat_paths, names(sobjs_raw))]
names(sobjs) <- names(mat_paths)


# Annotate cell subtypes ----

# Split into DCs, LECs, FRCs and annotate subtypes
# Each object is subset to its own cell type (so_types) and annotated with the
# matching subtype reference, at the resolution listed for that object
sobjs_sub <- imap(sobjs, function(so, so_name) {
  so_type <- so_types[[so_name]]
  
  so_subset <- subset_sobj(
    so,
    cell_type   = so_type,
    type_column = "cell_type"
  )
  
  run_clustifyr(
    so_subset,
    type_in    = so_type,
    ref        = subtype_refs[[so_type]],
    threshold  = subtype_thresh,
    resolution = subtype_res[[so_name]],
    prefix     = "t2"
  )
})


# Add subtype assignments and re-subset ----

# Additional cell types to include in objects
# For each entry of so_types: LEC objects also keep B cells, T cells and
# epithelial cells; every other object keeps B cells, T cells and NK cells.
# The choices are wrapped in list() so if_else() returns a whole character
# vector per entry, and flatten() then removes that extra list level
inc_types <- so_types %>%
  map(~ if_else(
    .x == "LEC",
    list(c("B cell", "T cell", "epithelial")),
    list(c("B cell", "T cell", "NK"))
  )) %>%
  flatten()

# Add subtype assignments back to main objects and split again to now include
# additional cell types (B cells, T cells, etc.) for plotting
sobjs <- imap(sobjs, function(so, so_name) {
  
  # Subtype annotations from the matching subset object
  sub_meta <- FetchData(
    sobjs_sub[[so_name]],
    c(
      "t2_type", "t2_r",
      "subtype", "subtype_cluster"
    )
  )
  
  # Re-subset to the main cell type plus the extra types
  so_new <- so %>%
    AddMetaData(sub_meta) %>%
    subset_sobj(
      cell_types   = c(so_types[[so_name]], inc_types[[so_name]]),
      type_column  = "cell_type",
      regress_vars = c("Percent_mito", "nCount_RNA", "S.Score", "G2M.Score")
    )
  
  # Store the UMAP coordinates in meta.data
  umap_coords <- FetchData(so_new, c("UMAP_1", "UMAP_2"))
  so_new      <- AddMetaData(so_new, umap_coords)
  
  # Cells without a subtype fall back to their cell type
  new_meta <- so_new@meta.data %>%
    rownames_to_column("cell_id") %>%
    mutate(
      subtype         = ifelse(is.na(subtype), cell_type, subtype),
      subtype         = str_to_title_v2(subtype),
      subtype_cluster = ifelse(is.na(subtype_cluster), cell_type, subtype_cluster)
    ) %>%
    column_to_rownames("cell_id")
  
  so_new@meta.data <- new_meta
  
  so_new
})


# Classify based on OVA signal ----

# Classify cells based on OVA signal
# The classification is restricted to each object's own cell type
sobjs <- imap(sobjs, function(so, so_name) {
  classify_ova(
    sobj_in     = so,
    filt_column = "cell_type",
    filt        = so_types[so_name],
    data_column = "adt_ovalbumin",
    data_slot   = "counts",
    quiet       = T
  )
})

# Split by subtype before classifying based on OVA
# classify_ova() is run once per subtype and returns a table instead of an
# object; the per-subtype tables are combined and added back as meta.data
sobjs <- map(sobjs, function(so_in) {
  
  # Classify the cells of a single subtype
  classify_subtype <- function(subtype_in) {
    classify_ova(
      sobj_in     = so_in,
      filt_column = "subtype",
      filt        = subtype_in,
      data_column = "adt_ovalbumin",
      data_slot   = "counts",
      quiet       = T,
      return_sobj = F
    )
  }
  
  subtype_res_list <- map(unique(so_in$subtype), classify_subtype)
  
  GMM_res <- subtype_res_list %>%
    bind_rows() %>%
    mutate(type_GMM_grp_2 = str_c(subtype, "-", GMM_grp)) %>%
    select(
      cell_id,
      type_GMM_grp = GMM_grp,
      type_GMM_grp_2,
      type_mu = mu
    ) %>%
    column_to_rownames("cell_id")
  
  AddMetaData(so_in, GMM_res)
})


# Write new objects ----

# Write each object to its output path, looked up by object name. The path is
# passed positionally because readr >= 1.4.0 renamed the `path` argument of
# write_rds() to `file` and deprecated `path`
sobjs %>%
  iwalk(~ write_rds(.x, so_out[.y]))


# ggplot2 themes
# Base theme: cowplot theme with plain-face titles and strip text and no
# strip background
theme_info <- theme_cowplot() +
  theme(
    plot.title       = element_text(face = "plain", size = 20),
    strip.background = element_blank(),
    strip.text       = element_text(face = "plain")
  )

# UMAP theme: base theme without axis text or ticks
umap_theme <- theme_info +
  theme(
    axis.text  = element_blank(),
    axis.ticks = element_blank()
  )

# Blank theme: UMAP theme without axis lines or titles
blank_theme <- umap_theme +
  theme(
    axis.line  = element_blank(),
    axis.title = element_blank()
  )

# Legend guides
# Legend keys drawn as solid points (shape 16) of size 3.5
col_guide <- guide_legend(override.aes = list(size = 3.5, shape = 16))

# Legend keys drawn as filled points with a thin black outline (shape 21)
outline_guide <- guide_legend(override.aes = list(
  size   = 3.5,
  shape  = 21,
  color  = "black",
  stroke = 0.25
))

# Okabe Ito color palettes
# The eight palette_OkabeIto colors with three extra colors interleaved
ito_cols <- c(
  palette_OkabeIto[1:4], "#d7301f", 
  palette_OkabeIto[5:6], "#6A51A3", 
  palette_OkabeIto[7:8], "#875C04"
)

# Extend the palette with darkened copies of every color from the third
# onwards, followed by a grey and black
ito_cols <- ito_cols[3:length(ito_cols)] %>%
  darken(0.4) %>%
  c(ito_cols, ., "#686868", "#000000")

# Set default palette
# get_cols(n) interpolates n colors from ito_cols; with no argument it
# returns as many colors as ito_cols contains
get_cols <- create_col_fun(ito_cols)

# OVA colors
ova_cols <- c(
  "ova high" = "#475C81",
  "ova low"  = "#B8ECFF",
  "Other"    = "#ffffff"
)

# Feature colors
feat_cols <- c(
  "#D7301F", "#0072B2",
  "#009E73", "#4C3250",
  "#E69F00", "#875C04"
)

# Cell subtype palettes
# Fixed colors for the cell types that are added to every object
# NOTE(review): "Epithelial" and "NK" share the same color - confirm this is
# intended
common_cols <- c(
  "Epithelial" = "#6A51A3",
  "B cell"     = "#E69F00",
  "T cell"     = "#009E73",
  "NK"         = "#6A51A3",
  "Unassigned" = "#999999"
)

# One palette per entry of so_types, built by set_type_cols() from the
# default palette and the fixed colors above
so_cols <- so_types %>%
  map(
    set_type_cols,
    sobjs_in   = sobjs,
    type_key   = so_types,
    cols_in    = get_cols(),
    other_cols = common_cols
  )

# Manual color overrides so selected LEC subtypes match between d2 and d14
so_cols$d2_LEC["Marco_LEC"]  <- so_cols$d14_LEC["Marco_LEC"]  <- "#CC79A7"
so_cols$d2_LEC["Collecting"] <- so_cols$d14_LEC["Collecting"] <- "#D7301F"
  
# Extra subtype colors added to the d2 LEC palette only
so_cols$d2_LEC <- c(
  so_cols$d2_LEC,
  Valve  = "#8C4651",
  Bridge = "#D55E00"
)
# Set equal x-axis scales
#
# Reads the x-range of the first panel of every plot in `gg_list_in`, then
# applies the overall minimum and maximum to each plot as coord_cartesian()
# limits. When `log_tran` is TRUE a log10 x scale with 10^x labels and fixed
# breaks (0.01, 1, 100, 10000) is added as well; `...` is forwarded to
# scale_x_log10(). Returns the list of modified plots.
equalize_x <- function(gg_list_in, log_tran = T, ...) {
  
  # x-range of the first panel of each plot
  x_ranges <- map(gg_list_in, function(gg) {
    built <- ggplot_build(gg)
    
    built$layout$panel_scales_x[[1]]$range$range
  })
  
  # Shared limits across all plots
  lower <- min(map_dbl(x_ranges, ~ .x[1]))
  upper <- max(map_dbl(x_ranges, ~ .x[2]))
  
  # Apply the shared limits (and optional log scale) to one plot
  apply_lims <- function(gg_in, min_x, max_x, log_tran, ...) {
    out <- gg_in +
      coord_cartesian(xlim = c(min_x, max_x))
    
    if (!log_tran) {
      return(out)
    }
    
    out +
      scale_x_log10(labels = trans_format("log10", math_format(10^.x)), ...)
  }
  
  map(
    gg_list_in,
    apply_lims,
    min_x    = lower, 
    max_x    = upper, 
    log_tran = log_tran,
    breaks   = c(0.01, 1, 100, 10000),
    ...
  )
}

# Create figure 3 panels
#
# Builds a subtype UMAP, violin plots of a signal per subtype and, when
# `umap_counts` is not NULL, a UMAP colored by that feature.
#
# Arguments:
#   sobj_in         Object passed to FetchData() (presumably a Seurat object).
#   cols_in         Named color vector, indexed by subtype.
#   subtype_column  Column holding the subtype labels.
#   data_slot       Slot passed to FetchData().
#   ova_cols        Colors for the feature UMAP.
#   box_counts      Feature plotted in the violins; if named, the name becomes
#                   the column name / axis label.
#   umap_counts     Feature plotted in the second UMAP; handled like
#                   box_counts. NULL skips the second UMAP.
#   plot_title      Title of the subtype UMAP; also used to name the output.
#   pt_size, pt_outline      Point settings for the subtype UMAP.
#   pt_size_2, pt_outline_2  Point settings for the feature UMAP.
#   box_cell_count  Label the violins with "<subtype>\n(n = <cells>)".
#   control_types   Subtypes moved to the front of the subtype order.
#   on_top          Subtypes moved to the end of the UMAP color vector
#                   (presumably so plot_features() draws them last - confirm).
#   ...             Forwarded to theme() for the subtype UMAP.
#
# Returns a list: subtype UMAP, violins and (optionally) feature UMAP, with
# every element named `plot_title`.
create_fig3 <- function(sobj_in, cols_in, subtype_column = "subtype", data_slot = "counts", ova_cols = c("#fafafa", "#d7301f"),
                        box_counts = c("Relative ova signal" = "ova_fc"), umap_counts = c("ova counts" = "adt_ovalbumin"), 
                        plot_title = NULL, pt_size = 0.1, pt_outline = 0.4, pt_size_2 = 0.3, pt_outline_2 = 0.5, box_cell_count = T,
                        control_types = c("B cell", "T cell"), on_top = NULL, ...) {
  
  # Defaults; the violin column and colors are switched below when cell
  # counts are shown
  box_column <- umap_column <- "cell_type"
  box_cols <- umap_cols <- cols_in
  
  # Fetch plotting data
  data_df <- sobj_in %>%
    FetchData(c(subtype_column, box_counts, umap_counts, "UMAP_1", "UMAP_2"), slot = data_slot) %>%
    as_tibble(rownames = "cell_id")
  
  # A named box_counts renames the fetched column to the name; from here on
  # box_counts holds the column name used in the data
  if (!is.null(names(box_counts))) {
    data_df <- data_df %>%
      rename(!!box_counts)
    
    box_counts <- names(box_counts)
  }
  
  # Same handling for umap_counts
  if (!is.null(names(umap_counts))) {
    data_df <- data_df %>%
      rename(!!umap_counts)
    
    umap_counts <- names(umap_counts)
  }
  
  # Set subtype order
  # Move select cell types to front of order
  # Subtypes are first ordered by the median of box_counts
  data_df <- data_df %>%
    mutate(
      cell_type = !!sym(subtype_column),
      cell_type = fct_reorder(cell_type, !!sym(box_counts), median)
    )
  
  type_order <- levels(data_df$cell_type)
  
  # Control types that are present are placed first
  if (!is.null(control_types)) {
    control_types <- control_types[control_types %in% type_order]
    type_order <- type_order[!type_order %in% control_types]
    type_order <- c(control_types, type_order)
  }
  
  # Count cells for each subtype
  # cell_count becomes a label "<subtype>\n(n = <cells>)" whose factor levels
  # follow type_order
  data_df <- data_df %>%
    group_by(cell_type) %>%
    mutate(cell_count = n_distinct(cell_id)) %>%
    ungroup() %>%
    mutate(cell_type = fct_relevel(cell_type, type_order)) %>%
    arrange(cell_type) %>%
    mutate(
      cell_count = str_c(cell_type, "\n(n = ", cell_count, ")"),
      cell_count = fct_inorder(cell_count)
    )
  
  # Set cell type colors
  # Pair each subtype with its count label and look up its color
  names(type_order) <- levels(data_df$cell_count)
  
  cols_df <- tibble(
    cell_type = type_order,
    cell_count = names(type_order)
  )
  
  cols_df <- cols_df %>%
    mutate(color = cols_in[cell_type])
  
  # Subtype UMAP
  # Move the on_top subtypes to the end of the color vector, which also sets
  # feat_levels below
  if (!is.null(on_top)) {
    on_top <- umap_cols[on_top]
    
    umap_cols <- umap_cols[!names(umap_cols) %in% names(on_top)]
    umap_cols <- c(umap_cols, on_top)
  }
  
  umap <- data_df %>%
    plot_features(
      feature     = umap_column,
      pt_size     = pt_size,
      pt_outline  = pt_outline,
      plot_cols   = umap_cols,
      feat_levels = names(umap_cols)
    ) +
    guides(color = guide_legend(override.aes = list(size = 3.5))) +
    ggtitle(plot_title) +
    blank_theme +
    theme(
      plot.title = element_text(size = 12),
      legend.position = "none"
    ) +
    theme(...)
  
  # OVA UMAP
  # min_pct / max_pct are passed to plot_features() (presumably percentile
  # limits for the color scale - confirm)
  if (!is.null(umap_counts)) {
    ova_umap <- data_df %>%
      plot_features(
        feature    = umap_counts,
        plot_cols  = ova_cols,
        pt_size    = pt_size_2,
        pt_outline = pt_outline_2,
        min_pct    = 0.01,
        max_pct    = 0.99
      ) +
      guides(color = guide_colorbar(frame.colour = "black", frame.linewidth = 0.2)) +
      blank_theme +
      theme(
        plot.title        = element_text(size = 10, hjust = 0.5),
        legend.position   = "right",
        legend.key.width  = unit(0.15, "cm"),
        legend.key.height = unit(0.30, "cm"),
        legend.title      = element_text(size = 10),
        legend.text       = element_text(size = 8)
      )
  }
  
  # OVA boxes
  # With cell counts shown, the violins use the count labels, so the colors
  # are renamed to match those labels
  if (box_cell_count) {
    box_column <- "cell_count"
    box_cols <- set_names(
      x  = cols_df$color,
      nm = cols_df$cell_count
    )
  }
  
  # Horizontal violins with 25% / 75% quantile lines and a point at the median
  boxes <- data_df %>%
    ggplot(aes(!!sym(box_counts), !!sym(box_column), fill = !!sym(box_column))) +
    geom_violin(size = 0.3, draw_quantiles = c(0.25, 0.75), alpha = 0.75) +
    stat_summary(geom = "point", color = "black", fun = median) +
    scale_color_manual(values = box_cols) +
    scale_fill_manual(values = box_cols) +
    theme_minimal_vgrid() +
    theme(
      legend.position    = "none",
      axis.title.y       = element_blank(),
      axis.title         = element_text(size = 10),
      axis.text          = element_text(size = 8),
      axis.ticks.x       = element_line(size = 0.1),
      panel.grid.major.x = element_line(size = 0.1)
    )
  
  res <- list(umap, boxes)
  
  if (!is.null(umap_counts)) {
    res <- append(res, list(ova_umap))
  }
  
  # Every element gets the same name; a NULL plot_title leaves the list
  # unnamed
  names(res) <- rep(plot_title, length(res))
  
  res
}

# Create figure 4 panels
create_fig4 <- function(sobj_in, gmm_column = "GMM_grp", ova_cols, feats, feat_cols, ref_cols, pt_size = 0.00001, pt_outline = 0.4,
                        show_bars = T, low_col = c("white", "white"), sep_bar_labs = F, plot_boxes = T, median_pt = 1, on_top = NULL) {

  # Theme elements
  legd_guide <- guide_legend(override.aes = list(
    size   = 3.5,
    shape  = 21,
    color  = "black",
    stroke = 0.25
  ))
  
  text_theme <- theme(
    axis.title  = element_text(size = 10),
    legend.text = element_text(size = 8),
    axis.text   = element_text(size = 8)
  )
  
  # Data for OVA group UMAP
  ova_order <- names(ova_cols)
  
  data_df <- sobj_in@meta.data %>%
    as_tibble(rownames = "cell_id") %>%
    group_by(!!sym(gmm_column)) %>%
    mutate(cell_count = n_distinct(cell_id)) %>%
    ungroup() %>%
    mutate(!!sym(gmm_column) := fct_relevel(!!sym(gmm_column), ova_order)) %>%
    arrange(!!sym(gmm_column)) %>%
    mutate(
      cell_count = str_c(!!sym(gmm_column), "\n(n = ", cell_count, ")"),
      cell_count = fct_inorder(cell_count)
    )
  
  # Set OVA group colors
  names(ova_order) <- levels(data_df$cell_count)
  
  cols_df <- tibble(
    grp        = ova_order,
    cell_count = names(ova_order)
  )
  
  cols_df <- cols_df %>%
    mutate(color = ova_cols[grp])
  
  umap_cols <- set_names(
    x  = cols_df$color,
    nm = cols_df$cell_count
  )

  ova_guide <- legd_guide
  ova_guide$reverse <- T
  
  # Create OVA group UMAP
  ova_grp_umap <- data_df %>%
    plot_features(
      feature     = "cell_count",
      pt_size     = pt_size,
      pt_outline  = pt_outline,
      plot_cols   = umap_cols,
      feat_levels = names(ova_order)
    ) +
    guides(color = ova_guide, fill = ova_guide) +
    text_theme +
    blank_theme +
    theme(
      plot.margin     = unit(c(0.2, 1, 0.2, 1.5), "cm"),
      legend.position = "top",
      legend.title    = element_blank(),
      legend.text     = element_text(size = 8)
    )
  
  # OVA hist
  ova_hist_order <- names(ova_cols) %>%
    grep("ova (low|high)", ., value = T)

  ova_hist <- sobj_in@meta.data %>%
    filter(!!sym(gmm_column) != "Other") %>%
    
    group_by(!!sym(gmm_column)) %>%
    mutate(mu = mean(adt_ovalbumin)) %>%
    ungroup() %>%
    
    mutate(!!sym(gmm_column) := fct_relevel(!!sym(gmm_column), ova_hist_order)) %>%
    ggplot(aes(adt_ovalbumin + 1, after_stat(density), fill = !!sym(gmm_column))) +
    
    stat_density(geom = "point", position = "identity", size = 0, color = "white") +
    geom_density(fill = "white", color = "white", size = 0.3, alpha = 0.9) +
    geom_density(size = 0.3, alpha = 0.8, show.legend = F) +
  
    geom_vline(aes(xintercept = mu), size = 0.5, linetype = 2, color = "grey35") +
    coord_cartesian(ylim = c(0, 1.7)) +
    scale_fill_manual(values = ova_cols) +
    scale_x_log10(labels = trans_format("log10", math_format(10^.x))) +
    labs(x = "ova counts", y = "Density") +
    guides(fill = guide_legend(override.aes = list(shape = 22, size = 3.5, stroke = 0.25, color = "black"))) +
    theme_minimal_hgrid() +
    text_theme +
    theme(
      plot.margin        = unit(c(0.2, 0.5, 0.2, 0.2), "cm"),
      legend.position    = c(0.05, 0.92),
      legend.key.height  = unit(0.15, "cm"),
      legend.title       = element_blank(),
      axis.line.y        = element_line(size = 0.5, color = "grey85"),
      axis.ticks.y       = element_line(size = 0.1),
      panel.grid.major.y = element_line(size = 0.1)
    )
  
  # OVA subtype bar graphs
  type_bar <- sobj_in@meta.data %>%
    as_tibble(rownames = "cell_id") %>%
    filter(!!sym(gmm_column) != "Other") %>%
    mutate(
      subtype = fct_reorder(subtype, cell_id, n_distinct),
      !!sym(gmm_column) := fct_relevel(!!sym(gmm_column), c("ova low", "ova high"))
    ) %>%
    ggplot(aes(!!sym(gmm_column), fill = subtype)) +
    
    stat_count(position = "fill", geom = "point", size = 0, color = "white") +
    geom_bar(position = "fill", size = 0.25, color = "black", show.legend = F) +
    
    scale_fill_manual(values = ref_cols) +
    scale_y_continuous(breaks = c(0, 0.5, 1)) +
    guides(fill = guide_legend(ncol = 1, override.aes = list(shape = 22, size = 3.5, stroke = 0.25, color = "black"))) +
    labs(y = "Fraction of cells") +
    theme_minimal_hgrid() +
    text_theme +
    theme(
      legend.title       = element_blank(),
      legend.key.height  = unit(0.35, "cm"),
      axis.title.x       = element_blank(),
      axis.line.y        = element_line(size = 0.5, color = "grey85"),
      axis.ticks.y       = element_line(size = 0.1),
      axis.text.x        = element_text(hjust = c(0.6, 0.4)),
      panel.grid.major.y = element_blank()
    )
  
  if (sep_bar_labs) {
    type_bar <- type_bar +
      theme(axis.text.x = element_text(hjust = c(0.8, 0.2)))
  }
  
  # Reference UMAP
  if (!is.null(on_top)) {
    on_top <- ref_cols[on_top]
    
    ref_cols <- ref_cols[!names(ref_cols) %in% names(on_top)]
    ref_cols <- c(ref_cols, on_top)
  }
  
  ref_umap <- sobj_in %>%
    create_ref_umap(
      feature     = "subtype",
      color_guide = legd_guide,
      plot_cols   = ref_cols,
      pt_mtplyr   = pt_size / 0.1,
      pt_outline  = pt_outline,
      feat_levels = names(ref_cols)
    ) +
    theme(
      plot.margin     = unit(c(0.2, 1.5, 0.2, 0.2), "cm"),
      legend.position = "left",
      legend.margin   = margin(0.2, 0.2, 0.2, 1.5, "cm"),
      legend.text     = element_text(size = 8)
    )
  
  # Create list of feature UMAPs
  feat_umaps <- sobj_in %>%
    create_marker_umaps(
      pt_mtplyr     = pt_size / 0.1,
      add_outline   = pt_outline,
      input_markers = feat_cols,
      low_col       = low_col
    ) %>%
    map(~ {
      .x +
        guides(color = guide_colorbar(frame.colour = "black", frame.linewidth = 0.2)) +
        theme(plot.title = element_text(size = 12))
    })
  
  # Top panel of feature UMAPs
  top_umaps <- append(list(ref_umap), feat_umaps[1:2]) %>%
    plot_grid(
      plotlist   = .,
      rel_widths = c(1, 0.5, 0.5),
      nrow       = 1
    )
  
  # Bottom panel of feature UMAPs
  bot_umaps <- feat_umaps[3:6] %>%
    plot_grid(
      plotlist = .,
      nrow     = 1,
      align    = "h",
      axis     = "tb"
    )
  
  # Final feature UMAP figure
  feat_umaps <- plot_grid(
    top_umaps, bot_umaps,
    ncol  = 1,
    align = "vh",
    axis  = "trbl"
  ) + 
    theme(plot.margin = unit(c(1, 0.2, 1.5, 0.2), "cm"))
  
  # Feature boxplots
  box_data <- sobj_in %>%
    FetchData(c(feats, "subtype", gmm_column)) %>%
    as_tibble(rownames = "cell_id") %>%
    pivot_longer(cols = c(-cell_id, -subtype, -!!sym(gmm_column))) %>%
    mutate(type_name = str_c(subtype, "_", name)) %>%
    group_by(type_name) %>%
    mutate(up_qt = boxplot.stats(value)$stats[4]) %>%
    ungroup() %>%
    mutate(
      name = fct_relevel(name, feats),
      type_name = fct_reorder(type_name, up_qt, median, .desc = T)
    )
    
    # TO ORDER GENES BY HIGHEST MEDIAN
    # group_by(subtype, name) %>%
    # mutate(med = median(value)) %>%
    # ungroup() %>%
    # mutate(name = fct_reorder(name, med, max, .desc = T))
    
  feat_boxes <- box_data %>%
    ggplot(aes(type_name, value, fill = subtype)) +
    facet_wrap(~ name, nrow = 1, scales = "free_x") +
    scale_fill_manual(values = ref_cols) +
    scale_color_manual(values = ref_cols) +
    labs(y = "Counts") +
    theme_minimal_hgrid() +
    text_theme +
    theme(
      strip.text         = element_text(size = 12),
      legend.title       = element_blank(),
      legend.key.height  = unit(0.35, "cm"),
      axis.title.x       = element_blank(),
      axis.text.x        = element_blank(),
      axis.line.x        = element_blank(),
      axis.ticks.x       = element_blank(),
      axis.line.y        = element_line(size = 0.5, color = "grey85"),
      panel.grid.major.y = element_blank(),
      panel.background   = element_rect(fill = "#fafafa")
    )
  
  feat_boxes <- feat_boxes +
    stat_summary(geom = "point", shape = 22, fun = median, size = 0) +
    stat_summary(geom = "point", shape = 22, fun = median, size = median_pt + 1, color = "black") +
    geom_boxplot(
      color          = "white",
      fill           = "white",
      alpha          = 1,
      size           = 0.3,
      width          = 0.6,
      outlier.colour = "white",
      outlier.alpha  = 1,
      outlier.size   = 0.1,
      coef           = 0
    ) +
    geom_boxplot(
      size           = 0.3,
      width          = 0.6,
      outlier.colour = "grey85",
      outlier.alpha  = 1,
      outlier.size   = 0.1,
      show.legend    = F,
      coef           = 0,
      fatten = 0
    ) +
    stat_summary(
      aes(color = subtype),
      geom        = "point",
      shape       = 22,
      fun         = median,
      size        = median_pt,
      stroke      = 1,
      fill        = "white",
      show.legend = F
    ) +
    guides(fill = guide_legend(override.aes = list(size = 3.5, stroke = 0.25)))
  
  # Final top panel
  blank_gg <- ggplot() +
    theme_void()
  
  top_lets <- c(letters[1:3], "")
  bot_lets <- c("", "d", "e")
  
  if (!show_bars) {
    type_bar <- blank_gg
    
    top_lets <- c(letters[1:2], "", "")
    bot_lets <- c("", "c", "d")
  }
  
  top <- plot_grid(
    ova_grp_umap, ova_hist, type_bar, blank_gg,
    rel_widths     = c(0.86, 1, 0.7, 0.2),
    labels         = top_lets,
    label_fontface = "plain",
    label_size     = 18,
    align          = "h",
    axis           = "tb",
    nrow           = 1
  )
  
  # Final bottom panel
  blank_gg <- ggplot() +
    theme_void()
  
  bot <- plot_grid(
    feat_boxes, blank_gg,
    rel_widths = c(1, if_else(plot_boxes, 0.45, 0.2))
  )
  
  # Create final figure
  res <- plot_grid(
    top, feat_umaps, feat_boxes,
    rel_heights    = c(0.55, 1.05, 0.35),
    ncol           = 1,
    labels         = bot_lets,
    label_fontface = "plain",
    label_size     = 18
  )
  
  res
}

# Create figure 4 panels
#
# Finds marker genes for the ova groups with find_markers(), splits the top
# hits into sets of six (create_fig4() lays out six feature UMAPs per figure),
# and prints one create_fig4() figure per gene set. Each figure is wrapped in
# markdown written with cat(): a "### v<i>" header before it and a rule after.
#
# sobj_in        Object with a meta.data slot containing gmm_column; passed on
#                to find_markers() and create_fig4().
# genes_in       Optional vector of genes plotted as the first figure, ahead of
#                the marker-derived gene sets.
# ref_cols, ova_cols, low_col, pt_size, pt_outline, sep_bar_labs, on_top,
# show_bars      Passed to create_fig4() unchanged.
# feat_cols      Feature colors; the first length(gene set) entries are named
#                with the genes of each set.
# n_panels       Number of marker-derived gene sets requested.
# gmm_column     meta.data column holding the ova group labels.
# ova_umap_grps  Labels kept in gmm_column; every other cell is relabeled
#                "Other" before plotting.
# ova_gmm_grps   Groups passed to find_markers() as groups_use.
# gmm_filt       Group(s) whose markers are used to build the gene sets.
#
# Example for subtype-specific groups:
#   gmm_column    = "type_GMM_grp_2"
#   ova_umap_grps = c("cDC2 Tbet--ova low", "cDC2 Tbet--ova high")
#
# Returns the list of plotted gene sets invisibly (the value of iwalk()).
create_fig4_panels <- function(sobj_in, genes_in = NULL, ref_cols, ova_cols, feat_cols, pt_size = 0.00001, pt_outline = 0.4,
                               low_col = c("white", "white"), sep_bar_labs = FALSE, n_panels = 3, on_top = NULL, show_bars = FALSE,
                               gmm_column = "GMM_grp", ova_umap_grps = c("ova low", "ova high"), ova_gmm_grps = ova_umap_grps,
                               gmm_filt = grep("ova high", ova_umap_grps, value = TRUE)) {
  
  # Differentially expressed genes for figure
  n_genes   <- 6
  tot_genes <- n_panels * n_genes
  
  # %in% (rather than ==) keeps the filter well-defined when gmm_filt holds
  # zero or several groups, which the grep() default can produce; unique()
  # guards against a gene being listed once per matching group
  diff_genes <- sobj_in %>%
    find_markers(
      group_column = gmm_column,
      groups_use   = ova_gmm_grps
    ) %>%
    filter(group %in% gmm_filt) %>%
    arrange(desc(logFC)) %>%
    pull(feature) %>%
    unique() %>%
    head(tot_genes)
  
  # Relabel cells outside ova_umap_grps as "Other"
  # as.character() is needed because ifelse() would otherwise return the
  # integer codes when the group column is a factor
  sobj_in@meta.data <- sobj_in@meta.data %>%
    rownames_to_column("cell_id") %>%
    mutate(!!sym(gmm_column) := ifelse(
      !(!!sym(gmm_column)) %in% ova_umap_grps,
      "Other",
      as.character(!!sym(gmm_column))
    )) %>%
    column_to_rownames("cell_id")
  
  # Split markers into complete sets of n_genes
  # seq_len() yields no sets (instead of c(1, 0) or an NA-padded set) when
  # fewer than n_genes markers were found
  n_sets <- length(diff_genes) %/% n_genes
  
  if (n_sets < n_panels) {
    warning(
      "Only ", length(diff_genes), " marker genes found; plotting ", n_sets,
      " of ", n_panels, " requested gene sets.",
      call. = FALSE
    )
  }
  
  feats <- seq_len(n_sets) %>%
    map(~ diff_genes[seq_len(n_genes) + ((.x - 1) * n_genes)])
  
  if (!is.null(genes_in)) {
    feats <- append(list(genes_in), feats)
  }
  
  # Create final figure
  iwalk(feats, ~ {
    umap_cols <- feat_cols[seq_along(.x)]
    umap_cols <- set_names(umap_cols, .x)
    cat("\n### v", .y, "\n", sep = "")
    
    create_fig4(
      sobj_in      = sobj_in,
      gmm_column   = gmm_column,
      ova_cols     = ova_cols,
      feats        = names(umap_cols),
      feat_cols    = umap_cols,
      ref_cols     = ref_cols,
      low_col      = low_col,
      pt_size      = pt_size,
      pt_outline   = pt_outline,
      show_bars    = show_bars,
      sep_bar_labs = sep_bar_labs,
      on_top       = on_top
    ) %>%
      print()
    
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  })
}

# Plot correlation
#
# Fetches x, y, and feat from sobj_in (FetchData() reads from data_slot),
# log10-transforms x and y, removes cells where either transformed value is
# -Inf, and plots the result with plot_features() with calc_cor and lm_line
# switched on. Arguments in ... are forwarded to plot_features().
# Note that plot_features() is always called with data_slot = "counts".
plot_corr <- function(sobj_in, x, y, feat, data_slot, cols_in, plot_title, ...) {
  
  x_sym <- sym(x)
  y_sym <- sym(y)
  
  # Pull plotting variables and move both axes to the log10 scale
  corr_data <- FetchData(sobj_in, c(x, y, feat), slot = data_slot)
  
  corr_data <- mutate(
    corr_data,
    !!x_sym := log10(!!x_sym),
    !!y_sym := log10(!!y_sym)
  )
  
  # Values of 0 become -Inf after the transform and are dropped
  corr_data <- filter(
    corr_data,
    (!!x_sym) != -Inf,
    (!!y_sym) != -Inf
  )
  
  # Scatter plot
  corr_plot <- plot_features(
    corr_data,
    x         = x,
    y         = y,
    feature   = feat,
    data_slot = "counts",
    plot_cols = cols_in,
    lab_pos   = c(0.9, 1),
    lab_size  = 5,
    calc_cor  = TRUE,
    lm_line   = TRUE,
    ...
  )
  
  # Title, legend key size, and theme
  legd_guide <- guide_legend(override.aes = list(size = 3.5))
  
  corr_plot +
    ggtitle(plot_title) +
    guides(color = legd_guide) +
    theme_info +
    theme(legend.title = element_blank())
}

# Calculate pairwise p-values for gg objects
#
# Runs pairwise.wilcox.test() with Bonferroni adjustment on the data stored in
# a ggplot object, comparing data_column between the groups in type_column,
# and reports the median of each group next to every p-value.
#
# gg_in        ggplot object; its $data element is used.
# sample_name  Value written to the Sample column of the result.
# data_column  Name (string) of the numeric column to test.
# type_column  Name (string) of the grouping column.
# log_tran     If TRUE, data_column is log10-transformed before the medians
#              and tests are calculated.
#
# Returns a data frame with the columns Sample, Cell type 1,
# Median OVA FC 1 (log10), Cell type 2, Median OVA FC 2 (log10), and p.value.
calc_p_vals <- function(gg_in, sample_name, data_column, type_column, log_tran = TRUE) {
  
  # Pull data from gg object
  gg_data <- gg_in$data
  
  # Log transform
  if (log_tran) {
    gg_data <- gg_data %>%
      mutate(!!sym(data_column) := log10(!!sym(data_column)))
  }
  
  # Calculate median
  gg_stats <- gg_data %>%
    group_by(!!sym(type_column)) %>%
    summarize(med = median(!!sym(data_column)))
  
  # Run wilcox test
  # Argument names and the adjustment method are spelled out in full instead
  # of relying on partial matching
  res <- gg_data[[data_column]] %>%
    pairwise.wilcox.test(
      g               = gg_data[[type_column]],
      p.adjust.method = "bonferroni"
    ) %>%
    tidy()
  
  # Add medians to data.frame
  # The join key is built from type_column so grouping columns other than
  # "cell_type" work. After the second join the group1 labels end up in
  # "<type_column>.y" (suffix added by the join) and the group2 labels in
  # "<type_column>"
  res <- gg_stats %>%
    rename(med_1 = med) %>%
    right_join(res, by = set_names("group1", type_column))
  
  res <- gg_stats %>%
    rename(med_2 = med) %>%
    right_join(res, by = set_names("group2", type_column))
  
  # Format final table
  res <- res %>%
    mutate(Sample = sample_name) %>%
    select(
      Sample,
      `Cell type 1`             = !!sym(str_c(type_column, ".y")),
      `Median OVA FC 1 (log10)` = med_1,
      `Cell type 2`             = !!sym(type_column),
      `Median OVA FC 2 (log10)` = med_2,
      p.value
    )
  
  res
}


# Subtypes listed for each object, keyed by object name
# Used below as the on_top entry of fig3_params; NULL means no subtypes are
# listed for that object
# NOTE(review): presumably these are the subtypes drawn last (on top) in the
# UMAPs - confirm against the function that consumes fig3_params
subtype_order <- list(
  d2_DC   = c("CCR7hi Mig cDC2", "CCR7hi Mig cDC1", "B cell", "T cell"),
  d14_DC  = c("CCR7hi Mig cDC2", "CCR7hi Mig cDC1", "cDC2 Tbet-", "cDC2 Tbet+", "B cell", "T cell", "NK"),
  d2_LEC  = NULL,
  d14_LEC = c("Ptx3_LEC", "fLEC", "cLEC", "Collecting", "BEC", "B cell", "T cell"),
  d2_FRC  = NULL,
  d14_FRC = NULL
)

# Parameter lists
# sobjs, so_titles, and so_cols are defined outside this section

# Objects used for figure 3, selected from sobjs by name
fig3_names <- c("d2_DC", "d14_DC", "d14_LEC", "d14_FRC")
fig3_sobjs <- sobjs[fig3_names]

# Parallel lists for figure 3; titles, colors, and on_top subtypes are all
# looked up with the names of fig3_sobjs so the entries line up by object
# NOTE(review): presumably iterated element-wise (e.g. with pmap) - confirm
fig3_params <- list(
  sobj_in    = fig3_sobjs, 
  plot_title = so_titles[names(fig3_sobjs)],
  cols_in    = so_cols[names(fig3_sobjs)],
  on_top     = subtype_order[names(fig3_sobjs)]
)

# Parallel lists covering every object in sobjs; the element names match the
# sobj_in, plot_title, and cols_in arguments of plot_corr()
corr_params <- list(
  sobj_in    = sobjs, 
  plot_title = so_titles[names(sobjs)], 
  cols_in    = so_cols[names(sobjs)]
)

Figure 4

LEC markers associated with high antigen counts. (a) LECs containing low and high antigen counts were identified using a Gaussian mixture model. A UMAP projection is shown for ova-low and ova-high cells. T cells, B cells, and epithelial cells are shown in white (Other). (b) The distribution of ova antigen counts is shown for ova-low and ova-high cells. Dotted lines indicate the mean counts for each population. (c) The fraction of cells belonging to each LEC subtype is shown for ova-low and ova-high populations. (d) Gene expression counts were plotted for select genes associated with ova-high LECs. (e) The expression of ova-high LEC markers is shown for each LEC subtype.

v1




v2




v3